AutoGPTQ/auto_gptq/nn_modules/qlinear/__init__.py
2023-08-11 14:52:26 +08:00

75 lines
3.2 KiB
Python

import torch.nn as nn
class GeneralQuantLinear(nn.Linear):
    """``nn.Linear`` subclass that wraps a quantized ``QuantLinear`` module.

    The wrapper is an ``nn.Linear`` instance, presumably so that tooling which
    matches on ``nn.Linear`` picks it up. All computation is delegated to the
    wrapped module's ``forward``. The quantized tensors (``qweight``,
    ``qzeros``, ``scales``, ``g_idx``) are re-registered on the wrapper as
    buffers.
    """

    # Kernel-specific attributes that only some QuantLinear variants define.
    # Each one is copied onto the wrapper only when the wrapped module has it.
    _OPTIONAL_ATTRS = (
        "wf",                       # qlinear_cuda, qlinear_cuda_old
        "kernel_switch_threshold",  # qlinear_cuda, qlinear_cuda_old
        "autogptq_cuda_available",  # qlinear_cuda, qlinear_cuda_old
        "autogptq_cuda",            # qlinear_cuda, qlinear_cuda_old
        "half_indim",               # qlinear_cuda_old
        "use_cuda_fp16",            # qlinear_cuda_old
        "_use_act_order",           # qlinear_exllama
        "width",                    # qlinear_exllama
        "q4",                       # qlinear_exllama
    )

    def __init__(self, quant_linear_module):
        """Wrap ``quant_linear_module``.

        Args:
            quant_linear_module: quantized linear module exposing
                ``infeatures``, ``outfeatures``, ``bits``, ``group_size``,
                ``maxq``, ``qweight``, ``qzeros``, ``scales``, ``g_idx``,
                ``bias`` (may be ``None``), ``trainable`` and ``forward``.
        """
        super().__init__(
            in_features=quant_linear_module.infeatures,
            out_features=quant_linear_module.outfeatures,
            bias=quant_linear_module.bias is not None,
        )
        self.infeatures = quant_linear_module.infeatures
        self.outfeatures = quant_linear_module.outfeatures
        self.bits = quant_linear_module.bits
        self.group_size = quant_linear_module.group_size
        self.maxq = quant_linear_module.maxq

        # The float weight that nn.Linear allocated is discarded. ``weight`` is
        # frozen and re-pointed at the packed quantized tensor, whose
        # shape/dtype need not match nn.Linear's usual layout.
        self.weight.requires_grad = False
        self.weight.data = quant_linear_module.qweight
        self.register_buffer('qweight', quant_linear_module.qweight)
        if quant_linear_module.bias is not None:
            self.bias.data = quant_linear_module.bias
        self.register_buffer('qzeros', quant_linear_module.qzeros)
        self.register_buffer('scales', quant_linear_module.scales)
        self.register_buffer('g_idx', quant_linear_module.g_idx)

        for attr in self._OPTIONAL_ATTRS:
            if hasattr(quant_linear_module, attr):
                setattr(self, attr, getattr(quant_linear_module, attr))

        self.trainable = quant_linear_module.trainable
        # NOTE(review): ``forward`` is the wrapped module's *bound* method, so
        # it reads that module's tensors, not the buffers registered above.
        # Both refer to the same tensor objects at construction time. However,
        # moving or casting only this wrapper afterwards (e.g. ``.to(device)``)
        # replaces the wrapper's buffers and leaves the wrapped module's
        # tensors in place. Confirm that callers finish device placement
        # before converting.
        self.forward = quant_linear_module.forward

    @classmethod
    def convert_to_torch_linear(cls, model: nn.Module, target_module_type: type):
        """Replace every ``target_module_type`` submodule of ``model`` in place.

        Each matching submodule is swapped for ``cls(submodule)`` on its
        parent. ``model`` itself is never replaced, even if it matches: it has
        no parent to re-attach a wrapper to. (Previously this case registered
        a bogus child under the empty name.)

        Args:
            model: Root module to search. It is modified in place.
            target_module_type: Class, or tuple of classes (anything
                ``isinstance`` accepts), identifying the modules to wrap.
        """
        # Snapshot the matches first so the module tree is not mutated while
        # the ``named_modules`` generator is still walking it.
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if name and isinstance(module, target_module_type)
        ]
        for name, module in targets:
            parent_name, _, child_name = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, cls(module))