24 lines
633 B
Python
24 lines
633 B
Python
from abc import abstractmethod
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class FusedBaseModule(nn.Module):
    """Common ``nn.Module`` base for the fused module family.

    Concrete subclasses are expected to override the ``inject_to_model``
    classmethod; this base class supplies only the abstract declaration.
    """

    # classmethod must wrap abstractmethod (abstractmethod innermost) so the
    # abstract flag is attached to the underlying function.
    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Override point for subclasses; the base version always raises.

        Accepts any positional/keyword arguments so that subclasses are free
        to define their own signature.

        NOTE(review): ``abstractmethod`` is only enforced by ``ABCMeta``.
        Nothing visible here adds that metaclass, so a subclass that forgets
        to override this is presumably still instantiable and only fails when
        the method is called — confirm.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules.

    Re-declares ``inject_to_model`` with the attention-specific keyword
    options that concrete attention subclasses are expected to accept.
    """

    # classmethod stays outermost so abstractmethod marks the raw function.
    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        **kwargs,
    ):
        """Override point for attention subclasses; always raises here.

        Args:
            model: object handed to the concrete implementation.
            use_triton: defaults to ``False``.
            group_size: defaults to ``-1``.
            use_cuda_fp16: defaults to ``True``.
            desc_act: defaults to ``False``.
            **kwargs: additional options; unused by this base version.

        The meaning of each option is defined by the concrete subclass
        (presumably kernel-selection / quantization settings — confirm).

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules.

    Re-declares ``inject_to_model`` with the keyword options that concrete
    MLP subclasses are expected to accept.
    """

    # classmethod stays outermost so abstractmethod marks the raw function.
    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        **kwargs,
    ):
        """Override point for MLP subclasses; always raises here.

        Args:
            model: object handed to the concrete implementation.
            use_triton: defaults to ``False``; its meaning is defined by the
                concrete subclass.
            **kwargs: additional options; unused by this base version.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|