AutoGPTQ/auto_gptq/nn_modules/_fused_base.py
2023-05-14 11:49:10 +08:00

37 lines
927 B
Python

from abc import abstractmethod
from logging import getLogger
import torch.nn as nn
# Module-level logger, used below to report a failed optional-backend import.
logger = getLogger(__name__)

# The fused base classes below inherit from TritonModuleMixin, so this module
# cannot be defined without it. Any ImportError raised by this import is
# logged and then re-raised unchanged (bare ``raise`` keeps the original
# traceback). Failure is never swallowed.
# NOTE(review): the log text assumes the cause is a missing triton install; an
# ImportError raised for any other reason while importing .triton_utils is
# reported with the same message — consider including the original error.
try:
    from .triton_utils import TritonModuleMixin
except ImportError:
    logger.error('triton not installed.')
    raise
class FusedBaseModule(nn.Module, TritonModuleMixin):
    """Common base for fused modules: a torch ``nn.Module`` plus ``TritonModuleMixin``.

    Concrete subclasses are expected to provide ``inject_to_model``.

    NOTE(review): ``@abstractmethod`` is only enforced when the class's
    metaclass is ``abc.ABCMeta``; unless ``TritonModuleMixin`` supplies such a
    metaclass, subclasses are not forced to override it and the base method
    simply raises at call time — confirm.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Abstract hook that subclasses override; the base version always raises.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules.

    Narrows ``inject_to_model`` to an attention-specific signature and adds an
    overridable ``warmup`` hook that does nothing by default.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
        """Abstract hook for concrete attention subclasses; the base always raises.

        Args:
            model: the model object handed to the subclass implementation.
            use_triton: defaults to ``False``.
            group_size: defaults to ``-1``.
            use_cuda_fp16: defaults to ``True``.
            desc_act: defaults to ``False``.
            **kwargs: extra options forwarded to subclass implementations.

        NOTE(review): the meaning of these options is defined by the concrete
        subclasses, not visible here — confirm against them.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """Optional warm-up hook; the base implementation is a no-op returning ``None``.

        Subclasses may override it. ``model``, ``transpose`` (default ``False``)
        and ``seqlen`` (default ``2048``) are accepted but unused here.
        """
class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules with an MLP-specific ``inject_to_model``."""

    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        """Abstract hook for concrete MLP subclasses; the base always raises.

        Args:
            model: the model object handed to the subclass implementation.
            use_triton: defaults to ``False``.
            **kwargs: extra options forwarded to subclass implementations.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError