AutoGPTQ/auto_gptq/nn_modules/fused_modules/mlp.py

from functools import partial
from typing import Union, Callable, Optional

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F

def act_dropout(
    hidden_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # activation followed by (optional) dropout, expressed as one traceable function
    hidden_states = activation(hidden_states)
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)


def dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    dropout: float = 0.0
):
    # (optional) dropout followed by residual addition
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)


def act_dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # activation -> (optional) dropout -> residual addition
    hidden_states = activation(hidden_states)
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)

class NVFusedActDropoutRes(nn.Module):
    """Optionally fused activation -> dropout -> residual-add block."""

    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        is_cuda: bool = False
    ):
        super(NVFusedActDropoutRes, self).__init__()

        # pick the elementwise composition that matches the requested pieces
        fn = partial(F.dropout, p=dropout)
        if activation is not None and residual:
            fn = partial(act_dropout_res, activation=activation, dropout=dropout)
        elif activation is not None:
            fn = partial(act_dropout, activation=activation, dropout=dropout)
        elif residual:
            fn = partial(dropout_res, dropout=dropout)

        self.fn = fn
        if is_cuda:
            # let nvFuser (via functorch) compile the elementwise chain into one kernel
            self.fn = memory_efficient_fusion(self.fn)

        self.residual = residual

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if self.residual:
            return self.fn(hidden_states, residual)
        else:
            return self.fn(hidden_states)
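
# Usage sketch (illustrative only, not part of the original file; the activation,
# dropout rate, and tensor names are assumptions): with an activation and a residual,
# the module chains activation -> dropout -> add, fused into one kernel on CUDA, e.g.
#
#   op = NVFusedActDropoutRes(activation=nn.GELU(), dropout=0.1, residual=True, is_cuda=True)
#   out = op(hidden_states, residual=hidden_states)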

class FusedMLP(nn.Module):
    """Two-layer MLP with the elementwise ops around each projection fused."""

    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False,
        residual: bool = False
    ):
        super(FusedMLP, self).__init__()

        if activation is None:
            activation = nn.Identity()
        is_cuda = input_proj.weight.data.device.type == "cuda"

        self.input_proj = input_proj
        # activation (+ dropout) applied to the up-projection output
        self.fused_op1 = NVFusedActDropoutRes(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            residual=False,
            is_cuda=is_cuda
        )
        self.out_proj = out_proj
        # dropout (+ optional residual add) applied to the down-projection output
        self.fused_op2 = NVFusedActDropoutRes(
            activation=None,
            dropout=out_dropout if training else 0.0,
            residual=residual,
            is_cuda=is_cuda
        )

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)
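
# Usage sketch (illustrative only; the layer sizes and activation below are assumptions,
# not values taken from this file):
#
#   mlp = FusedMLP(
#       input_proj=nn.Linear(1024, 4096).cuda(),
#       out_proj=nn.Linear(4096, 1024).cuda(),
#       activation=nn.GELU(),
#       residual=True,
#   )
#   out = mlp(hidden_states, residual=hidden_states)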

def gated_act_dropout(
    gate_states: torch.Tensor,
    up_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # gated activation (activation(gate) * up) followed by (optional) dropout
    hidden_states = activation(gate_states) * up_states
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)

class NVFusedGatedActDropout(nn.Module):
    """Optionally fused gated activation (gate * up) -> dropout block."""

    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        is_cuda: bool = False
    ):
        super(NVFusedGatedActDropout, self).__init__()

        # remember whether an activation was supplied: without one, self.fn is plain
        # dropout and expects a single, already-gated tensor
        self.gated = activation is not None
        fn = partial(F.dropout, p=dropout)
        if self.gated:
            fn = partial(gated_act_dropout, activation=activation, dropout=dropout)

        self.fn = fn
        if is_cuda:
            self.fn = memory_efficient_fusion(self.fn)

    def forward(self, gate_states: torch.Tensor, up_states: torch.Tensor):
        if not self.gated:
            # dropout-only fallback: form the gating product here, then drop out
            return self.fn(gate_states * up_states)
        return self.fn(gate_states, up_states)

class FusedGatedMLP(nn.Module):
    """Gated MLP whose gate/up projections share one Linear, with fused elementwise ops."""

    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False
    ):
        super(FusedGatedMLP, self).__init__()

        if activation is None:
            activation = nn.Identity()

        # input_proj holds both the gate and up projections stacked along the output dim
        self.input_proj = input_proj
        self.fused_op = NVFusedGatedActDropout(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            is_cuda=input_proj.weight.data.device.type == "cuda"
        )
        self.out_proj = out_proj
        self.out_dropout = nn.Dropout(out_dropout)
        self.intermediate_size = self.input_proj.out_features // 2

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_proj(hidden_states)
        # split the stacked projection back into its gate and up halves before the fused op
        return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
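
# Usage sketch (illustrative only; the sizes, activation, and the idea of stacking an
# existing gate_proj/up_proj pair into one Linear are assumptions, not code from this file):
#
#   gate_up = nn.Linear(4096, 2 * 11008, bias=False).cuda()  # gate weights, then up weights
#   down = nn.Linear(11008, 4096, bias=False).cuda()
#   mlp = FusedGatedMLP(input_proj=gate_up, out_proj=down, activation=nn.SiLU())
#   out = mlp(hidden_states)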