from functools import partial
from typing import Union, Callable, Optional

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F


def act_dropout(
    hidden_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # Activation followed by optional dropout.
    hidden_states = activation(hidden_states)
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, p=dropout)


def dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    dropout: float = 0.0
):
    # Optional dropout followed by a residual add.
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, p=dropout)
    return torch.add(hidden_states, residual)


def act_dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # Activation, optional dropout, then a residual add.
    hidden_states = activation(hidden_states)
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, p=dropout)
    return torch.add(hidden_states, residual)


class NVFusedActDropoutRes(nn.Module):
    # Fuses activation -> dropout -> residual add into a single callable and, on CUDA,
    # compiles it with memory_efficient_fusion (nvFuser) to cut intermediate memory traffic.
    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        is_cuda: bool = False
    ):
        super(NVFusedActDropoutRes, self).__init__()
        fn = partial(F.dropout, p=dropout)
        if activation is not None and residual:
            fn = partial(act_dropout_res, activation=activation, dropout=dropout)
        elif activation is not None:
            fn = partial(act_dropout, activation=activation, dropout=dropout)
        elif residual:
            fn = partial(dropout_res, dropout=dropout)
        self.fn = fn
        if is_cuda:
            self.fn = memory_efficient_fusion(self.fn)
        self.residual = residual

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if self.residual:
            return self.fn(hidden_states, residual)
        else:
            return self.fn(hidden_states)


class FusedMLP(nn.Module):
    # Standard transformer MLP: input projection -> fused(activation, dropout)
    # -> output projection -> fused(dropout, optional residual add).
    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False,
        residual: bool = False
    ):
        super(FusedMLP, self).__init__()
        if activation is None:
            activation = nn.Identity()
        is_cuda = input_proj.weight.data.device.type == "cuda"
        self.input_proj = input_proj
        self.fused_op1 = NVFusedActDropoutRes(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            residual=False,
            is_cuda=is_cuda
        )
        self.out_proj = out_proj
        self.fused_op2 = NVFusedActDropoutRes(
            activation=None,
            dropout=out_dropout if training else 0.0,
            residual=residual,
            is_cuda=is_cuda
        )

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        hidden_states = self.fused_op1(self.input_proj(hidden_states))
        return self.fused_op2(self.out_proj(hidden_states), residual)


def gated_act_dropout(
    gate_states: torch.Tensor,
    up_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    # GLU-style gating: activation(gate) * up, followed by optional dropout.
    hidden_states = activation(gate_states) * up_states
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, p=dropout)


class NVFusedGatedActDropout(nn.Module):
    # Fuses the gated activation and dropout; compiled with memory_efficient_fusion on CUDA.
    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        is_cuda: bool = False
    ):
        super(NVFusedGatedActDropout, self).__init__()
        self.gated = activation is not None
        fn = partial(F.dropout, p=dropout)
        if self.gated:
            fn = partial(gated_act_dropout, activation=activation, dropout=dropout)
        self.fn = fn
        if is_cuda:
            self.fn = memory_efficient_fusion(self.fn)

    def forward(self, gate_states: torch.Tensor, up_states: torch.Tensor):
        if not self.gated:
            # Without an activation the fused fn is plain dropout, which takes a
            # single tensor, so the gating product is applied here instead.
            return self.fn(gate_states * up_states)
        return self.fn(gate_states, up_states)
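# Illustrative usage sketch (assumption, not part of the original module): wraps two
# plain nn.Linear layers in a FusedMLP. The hidden sizes, GELU activation, dropout
# rates, and the helper name _example_fused_mlp are arbitrary example choices.
def _example_fused_mlp():
    hidden, intermediate = 64, 256
    mlp = FusedMLP(
        input_proj=nn.Linear(hidden, intermediate),
        out_proj=nn.Linear(intermediate, hidden),
        activation=nn.GELU(),
        in_dropout=0.1,
        out_dropout=0.1,
        training=True,
        residual=True
    )
    x = torch.randn(2, 8, hidden)
    # With residual=True, forward expects the residual tensor to be passed explicitly.
    out = mlp(x, residual=x)
    return out.shape  # torch.Size([2, 8, 64])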
class FusedGatedMLP(nn.Module):
    # Gated (GLU-style) MLP: the input projection produces both the gate and the up
    # states, which are split in half along the last dimension before the fused op.
    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False
    ):
        super(FusedGatedMLP, self).__init__()
        if activation is None:
            activation = nn.Identity()
        self.input_proj = input_proj
        self.fused_op = NVFusedGatedActDropout(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            is_cuda=input_proj.weight.data.device.type == "cuda"
        )
        self.out_proj = out_proj
        self.out_dropout = nn.Dropout(out_dropout)
        self.intermediate_size = self.input_proj.out_features // 2

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_proj(hidden_states)
        gate_states, up_states = hidden_states.chunk(chunks=2, dim=-1)
        return self.out_dropout(self.out_proj(self.fused_op(gate_states, up_states)))
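# Illustrative usage sketch (assumption, not part of the original module): a SwiGLU-style
# gated MLP. input_proj must emit 2 * intermediate features because forward chunks its
# output into gate and up halves; the sizes and the name _example_fused_gated_mlp are
# arbitrary example choices.
def _example_fused_gated_mlp():
    hidden, intermediate = 64, 256
    mlp = FusedGatedMLP(
        input_proj=nn.Linear(hidden, 2 * intermediate),
        out_proj=nn.Linear(intermediate, hidden),
        activation=nn.SiLU(),
        in_dropout=0.1,
        out_dropout=0.1,
        training=True
    )
    x = torch.randn(2, 8, hidden)
    out = mlp(x)
    return out.shape  # torch.Size([2, 8, 64])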