from abc import abstractmethod
from logging import getLogger

import torch.nn as nn

from .triton_utils.mixin import TritonModuleMixin

logger = getLogger(__name__)


class FusedBaseModule(nn.Module, TritonModuleMixin):
    """Common base for fused modules that can be injected into a model.

    Combines ``nn.Module`` with ``TritonModuleMixin`` and declares the
    ``inject_to_model`` class-level hook that subclasses are expected to
    override.
    """

    # NOTE(review): ``@abstractmethod`` only prevents instantiation when the
    # class's metaclass derives from ``ABCMeta`` -- confirm whether that is
    # the case here. The explicit ``raise`` below is what callers hit if the
    # hook is invoked without being overridden.
    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Inject this fused module into a model; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()


class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules."""

    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        trainable=False,
        **kwargs,
    ):
        """Inject the fused attention module into ``model``; must be overridden.

        None of the arguments are interpreted by this base implementation;
        their meaning is defined by concrete subclasses.

        Args:
            model: Target of the injection.
            use_triton: Defaults to ``False``.
            group_size: Defaults to ``-1``.
            use_cuda_fp16: Defaults to ``True``.
            desc_act: Defaults to ``False``.
            trainable: Defaults to ``False``.
            **kwargs: Additional subclass-specific options.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """Optional warm-up hook; a no-op in this base class.

        Args:
            model: Ignored here.
            transpose: Ignored here. Defaults to ``False``.
            seqlen: Ignored here. Defaults to ``2048``.

        Returns:
            None.
        """
        return None


class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules."""

    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        """Inject the fused MLP module into ``model``; must be overridden.

        None of the arguments are interpreted by this base implementation;
        their meaning is defined by concrete subclasses.

        Args:
            model: Target of the injection.
            use_triton: Defaults to ``False``.
            **kwargs: Additional subclass-specific options.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()