AutoGPTQ/auto_gptq/nn_modules/_fused_base.py
2023-05-12 18:09:23 +08:00

24 lines
633 B
Python

from abc import abstractmethod
import torch.nn as nn
class FusedBaseModule(nn.Module):
    """Common base class for fused modules that can be injected into a model.

    The class only declares the ``inject_to_model`` hook; it provides no
    working implementation of it.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Injection hook that concrete subclasses are expected to override.

        The base implementation accepts arbitrary positional and keyword
        arguments and ignores all of them.

        Raises:
            NotImplementedError: always, when this base version is called.
        """
        # NOTE(review): ``abstractmethod`` only blocks instantiation when the
        # class's metaclass derives from ``ABCMeta`` -- confirm whether that
        # holds for ``nn.Module``; otherwise this raise is the only guard.
        raise NotImplementedError
class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules.

    Narrows the ``inject_to_model`` hook inherited from ``FusedBaseModule``
    to an attention-specific signature; still provides no implementation.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        **kwargs
    ):
        """Injection hook that concrete attention subclasses should override.

        Args:
            model: the object to inject into (presumably the model being
                patched -- TODO confirm against concrete subclasses).
            use_triton: defaults to ``False``.
            group_size: defaults to ``-1``.
            use_cuda_fp16: defaults to ``True``.
            desc_act: defaults to ``False``.
            **kwargs: any additional keyword arguments.

        None of the arguments are used by this base implementation; their
        meaning is defined by the overriding subclass.

        Raises:
            NotImplementedError: always, when this base version is called.
        """
        raise NotImplementedError
class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules.

    Narrows the ``inject_to_model`` hook inherited from ``FusedBaseModule``
    to an MLP-specific signature; still provides no implementation.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        """Injection hook that concrete MLP subclasses should override.

        Args:
            model: the object to inject into (presumably the model being
                patched -- TODO confirm against concrete subclasses).
            use_triton: defaults to ``False``.
            **kwargs: any additional keyword arguments.

        None of the arguments are used by this base implementation; their
        meaning is defined by the overriding subclass.

        Raises:
            NotImplementedError: always, when this base version is called.
        """
        raise NotImplementedError