from functools import partial
from typing import Union, Callable, Optional

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F


def act_dropout(
    hidden_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0,
):
    """Apply an activation followed by (optional) dropout."""
    hidden_states = activation(hidden_states)
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)


def dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    dropout: float = 0.0,
):
    """Apply (optional) dropout, then add the residual branch."""
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)


def act_dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0,
):
    """Apply an activation, (optional) dropout, then add the residual branch."""
    hidden_states = activation(hidden_states)
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)


class NVFusedActDropoutRes(nn.Module):
    """Activation + dropout (+ residual add), optionally fused with nvFuser.

    The elementwise chain is selected at construction time and, on CUDA,
    compiled with functorch's ``memory_efficient_fusion``.
    """

    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        is_cuda: bool = False,
    ):
        super().__init__()

        # Pick the elementwise chain that matches the requested options.
        fn = partial(F.dropout, p=dropout)
        if activation is not None and residual:
            fn = partial(act_dropout_res, activation=activation, dropout=dropout)
        elif activation is not None:
            fn = partial(act_dropout, activation=activation, dropout=dropout)
        elif residual:
            fn = partial(dropout_res, dropout=dropout)

        self.fn = fn
        if is_cuda:
            # Compile the elementwise chain into a single fused kernel.
            self.fn = memory_efficient_fusion(self.fn)

        self.residual = residual

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if self.residual:
            return self.fn(hidden_states, residual)
        return self.fn(hidden_states)


class FusedMLP(nn.Module):
    """Standard MLP block whose elementwise ops are fused with nvFuser on CUDA.

    Computes ``out_proj(act(input_proj(x)))`` with dropout after each projection
    and an optional residual add on the output. Dropout is only active when the
    block is constructed with ``training=True``.
    """

    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False,
        residual: bool = False,
    ):
        super().__init__()

        if activation is None:
            activation = nn.Identity()

        # Only enable nvFuser when the weights already live on a CUDA device.
        is_cuda = input_proj.weight.data.device.type == "cuda"

        self.input_proj = input_proj
        self.fused_op1 = NVFusedActDropoutRes(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            residual=False,
            is_cuda=is_cuda,
        )
        self.out_proj = out_proj
        self.fused_op2 = NVFusedActDropoutRes(
            activation=None,
            dropout=out_dropout if training else 0.0,
            residual=residual,
            is_cuda=is_cuda,
        )

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)


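# Illustrative-only sanity check for FusedMLP; the layer sizes below are
# arbitrary and chosen just for demonstration. On CPU the module falls back to
# the unfused Python functions; the nvFuser path requires CUDA-resident weights.
def _example_fused_mlp():
    hidden = torch.randn(2, 8, 64)
    residual = torch.randn(2, 8, 64)
    mlp = FusedMLP(
        input_proj=nn.Linear(64, 256),
        out_proj=nn.Linear(256, 64),
        activation=nn.GELU(),
        in_dropout=0.1,
        out_dropout=0.1,
        training=False,
        residual=True,
    )
    out = mlp(hidden, residual)
    assert out.shape == hidden.shape
    return out

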
def gated_act_dropout(
    gate_states: torch.Tensor,
    up_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0,
):
    """Gated activation (e.g. SwiGLU-style ``act(gate) * up``) followed by (optional) dropout."""
    hidden_states = activation(gate_states) * up_states
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)


class NVFusedGatedActDropout(nn.Module):
    """Gated activation + dropout, optionally fused with nvFuser."""

    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        is_cuda: bool = False,
    ):
        super().__init__()

        # Without an activation, fall back to plain dropout on the gated product.
        self.gated = activation is not None
        fn = partial(F.dropout, p=dropout)
        if self.gated:
            fn = partial(gated_act_dropout, activation=activation, dropout=dropout)

        self.fn = fn
        if is_cuda:
            # Compile the elementwise chain into a single fused kernel.
            self.fn = memory_efficient_fusion(self.fn)

    def forward(self, gate_states: torch.Tensor, up_states: torch.Tensor):
        if not self.gated:
            # Plain dropout takes a single tensor, so apply the gating here.
            return self.fn(gate_states * up_states)
        return self.fn(gate_states, up_states)


class FusedGatedMLP(nn.Module):
    """Gated MLP block with nvFuser-fused gated activation and dropout.

    ``input_proj`` must produce the gate and up projections concatenated along
    the last dimension; they are split with ``chunk`` before the gated activation.
    """

    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False,
    ):
        super().__init__()

        if activation is None:
            activation = nn.Identity()

        self.input_proj = input_proj
        self.fused_op = NVFusedGatedActDropout(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            is_cuda=input_proj.weight.data.device.type == "cuda",
        )
        self.out_proj = out_proj
        self.out_dropout = nn.Dropout(out_dropout)

        self.intermediate_size = self.input_proj.out_features // 2

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_proj(hidden_states)
        return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
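

# Illustrative-only sanity check for FusedGatedMLP with arbitrary sizes:
# input_proj packs the gate and up projections, so its out_features must be
# twice the intermediate size consumed by out_proj.
def _example_fused_gated_mlp():
    hidden = torch.randn(2, 8, 64)
    mlp = FusedGatedMLP(
        input_proj=nn.Linear(64, 512),  # gate (256) and up (256) concatenated
        out_proj=nn.Linear(256, 64),
        activation=nn.SiLU(),
        in_dropout=0.1,
        out_dropout=0.1,
        training=False,
    )
    out = mlp(hidden)
    assert out.shape == hidden.shape
    return out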