66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
import torch
|
|
|
|
from ..qlinear import GeneralQuantLinear
|
|
from ..qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
|
|
from ..qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
|
|
from ..qlinear.qlinear_triton import QuantLinear as TritonQuantLinear
|
|
|
|
|
|
class FusedGeneralQuantLinear(GeneralQuantLinear):
    """``GeneralQuantLinear`` around one QuantLinear made by fusing several
    quantized projections (e.g. q/k/v) along the output-feature axis."""

    def __init__(self, quant_linear_module):
        super().__init__(quant_linear_module)

    @staticmethod
    def _concat(tensors, dim):
        """Concatenate ``tensors`` along ``dim``; a lone tensor is returned as-is (no copy)."""
        if len(tensors) == 1:
            return tensors[0]
        return torch.cat(tensors, dim=dim)

    @classmethod
    def fuse(
        cls,
        q_proj,
        k_proj=None,
        v_proj=None,
    ):
        """Fuse up to three quantized linear modules into a single one.

        The packed weights, zero points, scales and biases of the given modules
        are concatenated along the output-feature axis, in q, k, v order. The
        resulting layer's kernel class (Triton / CUDA / old CUDA) and its
        kernel options are taken from ``q_proj``.

        Args:
            q_proj: First quantized linear module; required. Its ``bits``,
                ``group_size``, ``infeatures`` and ``trainable`` (plus
                ``kernel_switch_threshold`` / ``use_cuda_fp16`` for the CUDA
                variants) configure the fused layer.
            k_proj: Optional second module, appended after ``q_proj``.
            v_proj: Optional third module, appended after ``k_proj``.

        Returns:
            A new ``cls`` instance wrapping the fused QuantLinear.

        Raises:
            ValueError: if ``q_proj`` is None, or if the modules' ``g_idx``
                buffers differ (a single fused layer can hold only one).
        """
        if q_proj is None:
            raise ValueError("q_proj is required to build a fused projection")
        modules = [m for m in (q_proj, k_proj, v_proj) if m is not None]

        # g_idx maps each *input* feature to its quantization group (the
        # QuantLinear constructors size it by infeatures). Fusion stacks
        # outputs, so the input side is shared: keep one g_idx instead of
        # concatenating them (which produced a buffer N times too long).
        g_idx = q_proj.g_idx
        for module in modules[1:]:
            if not torch.equal(module.g_idx, g_idx):
                raise ValueError(
                    "cannot fuse projections whose g_idx differ; "
                    "a fused layer supports a single input-group mapping"
                )

        # dim=1 is the output-feature axis of the packed qweight/qzeros/scales.
        qweight = cls._concat([m.qweight for m in modules], dim=1)
        qzeros = cls._concat([m.qzeros for m in modules], dim=1)
        scales = cls._concat([m.scales for m in modules], dim=1)
        outfeatures = sum(m.outfeatures for m in modules)

        biases = [m.bias for m in modules]
        reference = next((b for b in biases if b is not None), None)
        if reference is None:
            bias = None
        else:
            # A projection without bias is equivalent to one with a zero bias,
            # so fill the gaps instead of crashing in torch.cat or dropping
            # the other projections' biases.
            bias = cls._concat(
                [
                    b
                    if b is not None
                    else torch.zeros(
                        m.outfeatures, dtype=reference.dtype, device=reference.device
                    )
                    for m, b in zip(modules, biases)
                ],
                dim=0,
            )

        qlinear_kwargs = {"trainable": q_proj.trainable}
        if isinstance(q_proj, TritonQuantLinear):
            quant_linear_cls = TritonQuantLinear
        else:
            qlinear_kwargs["kernel_switch_threshold"] = q_proj.kernel_switch_threshold
            if isinstance(q_proj, OldCudaQuantLinear):
                qlinear_kwargs["use_cuda_fp16"] = q_proj.use_cuda_fp16
                quant_linear_cls = OldCudaQuantLinear
            else:
                quant_linear_cls = CudaQuantLinear

        fused_proj = quant_linear_cls(
            q_proj.bits,
            q_proj.group_size,
            q_proj.infeatures,
            outfeatures,
            bias is not None,
            **qlinear_kwargs,
        )

        # Overwrite the freshly allocated buffers with the fused tensors.
        fused_proj.qweight = qweight
        fused_proj.qzeros = qzeros
        fused_proj.scales = scales
        fused_proj.g_idx = g_idx
        fused_proj.bias = bias  # previously misspelled `bais`, so bias was lost

        return cls(fused_proj)
|