56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
import torch.nn as nn
|
|
|
|
|
|
class GeneralQuantLinear(nn.Linear):
|
|
def __init__(self, quant_linear_module):
|
|
super().__init__(
|
|
in_features=quant_linear_module.infeatures,
|
|
out_features=quant_linear_module.outfeatures,
|
|
bias=True
|
|
)
|
|
self.infeatures = quant_linear_module.infeatures
|
|
self.outfeatures = quant_linear_module.outfeatures
|
|
self.bits = quant_linear_module.bits
|
|
self.group_size = quant_linear_module.group_size
|
|
self.maxq = quant_linear_module.maxq
|
|
|
|
self.weight.requires_grad = False
|
|
|
|
self.weight.data = quant_linear_module.qweight
|
|
self.register_buffer('qweight', quant_linear_module.qweight)
|
|
self.bias.data = quant_linear_module.bias
|
|
|
|
self.qweight.requires_grad = False
|
|
self.bias.requires_grad = False
|
|
|
|
self.register_buffer('qzeros', quant_linear_module.qzeros)
|
|
self.register_buffer('scales', quant_linear_module.scales)
|
|
self.register_buffer('g_idx', quant_linear_module.g_idx)
|
|
|
|
if hasattr(quant_linear_module, "wf"):
|
|
self.wf = quant_linear_module.wf
|
|
if hasattr(quant_linear_module, "kernel_switch_threshold"):
|
|
self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
|
|
if hasattr(quant_linear_module, "autogptq_cuda_available"):
|
|
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
|
|
|
|
self.trainable = quant_linear_module.trainable
|
|
|
|
self.forward = quant_linear_module.forward
|
|
|
|
@classmethod
|
|
def inject_to_model(cls, model, target_module_type):
|
|
for name, m in model.named_modules():
|
|
if not isinstance(m, target_module_type):
|
|
continue
|
|
new_m = cls(m)
|
|
if '.' in name:
|
|
parent_name = name.rsplit('.', 1)[0]
|
|
child_name = name[len(parent_name) + 1:]
|
|
parent = model.get_submodule(parent_name)
|
|
else:
|
|
parent_name = ''
|
|
parent = model
|
|
child_name = name
|
|
|
|
setattr(parent, child_name, new_m)
|