# Source listing: 42 lines, 939 B, Python.
from abc import abstractmethod
from logging import getLogger
import torch.nn as nn
from .triton_utils.mixin import TritonModuleMixin


# Module-scoped logger, named after this module.
logger = getLogger(name=__name__)


class FusedBaseModule(nn.Module, TritonModuleMixin):
    """Common base for fused modules, combining ``nn.Module`` and ``TritonModuleMixin``.

    Subclasses are expected to override the ``inject_to_model`` classmethod.

    NOTE(review): ``nn.Module`` does not use the ``ABCMeta`` metaclass.
    Unless ``TritonModuleMixin`` supplies one (not visible here),
    ``@abstractmethod`` is therefore not enforced when a subclass is
    instantiated. A missing override is only detected at call time, via the
    ``NotImplementedError`` raised below.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Subclass hook; must be overridden.

        The name suggests it patches a model with this fused module.
        Confirm against the concrete subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError


class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules.

    Narrows ``inject_to_model`` to an attention-specific signature and
    provides a default ``warmup`` that does nothing.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        trainable=False,
        **kwargs,
    ):
        """Subclass hook for patching ``model``; must be overridden.

        NOTE(review): the keyword options (``group_size=-1``, ``desc_act``,
        ``use_cuda_fp16``, ...) look like quantization/kernel settings.
        Their exact meaning is defined by the concrete subclasses; confirm there.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """Default no-op: ignore every argument and return ``None``.

        Subclasses may override this. Presumably the override warms up or
        pre-compiles kernels for ``seqlen``; confirm against implementations.
        """
        return None


class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules."""

    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        """Subclass hook for patching ``model``; must be overridden.

        Presumably it swaps the model's MLP layers for fused ones; confirm
        against the concrete subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError