add trainable mode
commit 2b532f9453 (parent fe5f5d12ed)

12 changed files with 169 additions and 22 deletions
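
Taken together, the diff threads a new `trainable` flag from the loading API down to every quantized linear layer, and adds helpers to switch a loaded model in and out of trainable mode. A minimal usage sketch, assuming the standard AutoGPTQ entry point `AutoGPTQForCausalLM.from_quantized` and a placeholder model directory:

from auto_gptq import AutoGPTQForCausalLM

# placeholder path to a directory that already contains a quantized model
quantized_model_dir = "path/to/quantized-model"

# trainable mode currently requires the triton backend (see the checks added below)
model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir,
    use_triton=True,
    trainable=True,
)

model.enable_trainable_mode()   # flip .trainable on every quantized submodule
# ... run fine-tuning / gradient computation here ...
model.disable_trainable_mode()  # return to the inference-only forward path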

@@ -23,7 +23,7 @@ from ..nn_modules.qlinear import GeneralQuantLinear
 from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
-from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE
+from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE

 logger = getLogger(__name__)

@@ -77,7 +77,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
     fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

-    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        quantized: bool,
+        quantize_config: BaseQuantizeConfig,
+        is_triton_backend: bool = False,
+        injected_fused_attention: bool = False,
+        injected_fused_mlp: bool = False,
+        trainable: bool = False
+    ):
         super().__init__()

         self.model = model
@@ -86,6 +95,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         self.quantize_config = quantize_config
         self.config = self.model.config

+        self.is_triton_backend = is_triton_backend
+        self.injected_fused_attention = injected_fused_attention
+        self.injected_fused_mlp = injected_fused_mlp
+        self.trainable = trainable
+
     @property
     def quantized(self):
         return self._quantized
@@ -510,6 +524,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ):
         """load quantized model from local disk"""
@@ -517,7 +532,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             logger.warning("triton is not installed, reset use_triton to False")
             use_triton = False

-        # == step1: prepare configs and file names == #
+        # == step1: prepare configs and file names, and check values of arguments passed in == #
         config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")
@@ -543,6 +558,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         else:
             raise FileNotFoundError(f"Could not find model at {model_save_name}")

+        if not use_triton and trainable:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+
         # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
         def skip(*args, **kwargs):
             pass
@@ -578,7 +596,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             quantize_config.group_size,
             use_triton=use_triton,
             use_cuda_fp16=use_cuda_fp16,
-            desc_act=quantize_config.desc_act
+            desc_act=quantize_config.desc_act,
+            trainable=trainable
         )
         model.tie_weights()

@@ -643,6 +662,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         # == step5: (optional) inject optimized module == #
         if inject_fused_attention:
             if cls.fused_attn_module_type is None:
+                inject_fused_attention = False
                 logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
             else:
                 cls.fused_attn_module_type.inject_to_model(
@@ -650,10 +670,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                     use_triton=use_triton,
                     group_size=quantize_config.group_size,
                     use_cuda_fp16=use_cuda_fp16,
-                    desc_act=quantize_config.desc_act
+                    desc_act=quantize_config.desc_act,
+                    trainable=trainable
                 )
         if inject_fused_mlp:
             if cls.fused_mlp_module_type is None:
+                inject_fused_mlp = False
                 logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
             else:
                 cls.fused_mlp_module_type.inject_to_model(
@@ -670,7 +692,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             if inject_fused_mlp and cls.fused_mlp_module_type is not None:
                 cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)

-        return cls(model, True, quantize_config)
+        return cls(
+            model,
+            True,
+            quantize_config,
+            is_triton_backend=use_triton,
+            injected_fused_attention=inject_fused_attention,
+            injected_fused_mlp=inject_fused_mlp and use_triton,
+            trainable=trainable
+        )

     def warmup_triton(self, enabled: bool = True):
         if not enabled:
@@ -685,6 +715,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if self.fused_mlp_module_type is not None:
             self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)

+    def enable_trainable_mode(self, enabled: bool = True):
+        if not self.is_triton_backend and enabled:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+        for n, m in self.model.named_modules():
+            if hasattr(m, "trainable"):
+                setattr(m, "trainable", enabled)
+
+    def disable_trainable_mode(self):
+        self.enable_trainable_mode(enabled=False)
+
     def __getattr__(self, item):
         try:
             return super().__getattr__(item)
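
The new enable_trainable_mode/disable_trainable_mode helpers above simply walk self.model.named_modules() and flip the `trainable` attribute wherever a module exposes one. The same pattern in isolation, on a hypothetical module tree (FakeQuantLinear is a made-up stand-in, not part of the codebase):

import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    """Stand-in for a quantized linear layer that exposes a `trainable` flag."""
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.trainable = False

root = nn.Sequential(FakeQuantLinear(8, 8), nn.ReLU(), FakeQuantLinear(8, 4))

# same loop the helper uses: set the flag wherever it exists
for _, m in root.named_modules():
    if hasattr(m, "trainable"):
        setattr(m, "trainable", True)

print([type(m).__name__ for _, m in root.named_modules() if getattr(m, "trainable", False)])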

@@ -50,7 +50,17 @@ def get_module_by_name_suffix(model, module_name: str):
             return module


-def make_quant(module, names, bits, group_size, name='', use_triton=False, use_cuda_fp16=True, desc_act=False):
+def make_quant(
+    module,
+    names,
+    bits,
+    group_size,
+    name='',
+    use_triton=False,
+    use_cuda_fp16=True,
+    desc_act=False,
+    trainable=False
+):
     QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

     if isinstance(module, QuantLinear):
@@ -71,13 +81,25 @@ def make_quant(module, names, bits, group_size, name='', use_triton=False, use_cuda_fp16=True, desc_act=False):
                in_features = tmp.weight.shape[0]
                out_features = tmp.weight.shape[1]
            if (not(desc_act) or group_size == -1) and not use_triton:
-               new_layer = QuantLinear(bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16)
+               new_layer = QuantLinear(
+                   bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
+               )
            else:
-               new_layer = QuantLinear(bits, group_size, in_features, out_features, True)
+               new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
            new_layer.device = ori_layer_device
            setattr(module, attr, new_layer.to(ori_layer_device))
    for name1, child in module.named_children():
-       make_quant(child, names, bits, group_size, name + '.' + name1 if name != '' else name1, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16,desc_act=desc_act)
+       make_quant(
+           child,
+           names,
+           bits,
+           group_size,
+           name + '.' + name1 if name != '' else name1,
+           use_triton=use_triton,
+           use_cuda_fp16=use_cuda_fp16,
+           desc_act=desc_act,
+           trainable=trainable
+       )


 def pack_model(
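
make_quant now threads `trainable` through its recursion, so every QuantLinear it constructs gets the same setting. A hedged example call (the model object and the layer-name list are placeholders, not values taken from this commit):

# `model` is assumed to be a transformers model whose matching Linear layers
# should be replaced by QuantLinear modules
make_quant(
    model,
    names=["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],  # placeholder module names
    bits=4,
    group_size=128,
    use_triton=True,
    desc_act=False,
    trainable=True,  # propagated to every QuantLinear created during the recursion
)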

@@ -65,6 +65,7 @@ class AutoGPTQForCausalLM:
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(save_dir, trust_remote_code)
@@ -85,6 +86,7 @@ class AutoGPTQForCausalLM:
             use_safetensors=use_safetensors,
             trust_remote_code=trust_remote_code,
             warmup_triton=warmup_triton,
+            trainable=trainable,
             **keywords
         )


@@ -18,7 +18,16 @@ class FusedBaseModule(nn.Module, TritonModuleMixin):
 class FusedBaseAttentionModule(FusedBaseModule):
     @classmethod
     @abstractmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         raise NotImplementedError()

     @classmethod

@@ -226,7 +226,16 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         return outputs # a, present, (attentions)

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

@@ -253,7 +262,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
             True if q_proj.bias is not None else False,
         )
-        qlinear_kwargs = dict()
+        qlinear_kwargs = {"trainable": trainable}
         if (not desc_act or group_size == -1) and not use_triton:
             qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
         qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)

@@ -126,7 +126,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         return attn_output, attn_weights, past_key_value

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         """
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
@@ -153,7 +162,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
             True if q_proj.bias is not None else False,
         )
-        qlinear_kwargs = dict()
+        qlinear_kwargs = {"trainable": trainable}
         if (not desc_act or group_size == -1) and not use_triton:
             qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
         qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)

@@ -35,6 +35,8 @@ class GeneralQuantLinear(nn.Linear):
         if hasattr(quant_linear_module, "autogptq_cuda_available"):
             self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available

+        self.trainable = quant_linear_module.trainable
+
         self.forward = quant_linear_module.forward

     @classmethod

@@ -26,10 +26,13 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         kernel_switch_threshold=128,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")

         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -76,6 +79,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):

@@ -16,6 +16,7 @@ except ImportError:
     logger.warning('CUDA extension not installed.')
     _autogptq_cuda_available = False

+
 class QuantLinear(nn.Module):
     def __init__(
         self,
@@ -25,12 +26,15 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         use_cuda_fp16=True,
-        kernel_switch_threshold=128
+        kernel_switch_threshold=128,
+        trainable=False
     ):

         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
@@ -80,6 +84,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx):
         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()

@@ -11,7 +11,10 @@ from ..triton_utils.mixin import TritonModuleMixin
 logger = getLogger(__name__)

 try:
-    from ..triton_utils.kernels import quant_matmul_248, transpose_quant_matmul_248, QuantLinearFunction
+    from ..triton_utils.kernels import (
+        quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
+        QuantLinearFunction, QuantLinearInferenceOnlyFunction
+    )
 except ImportError:
     logger.error('triton not installed.')
     raise
@@ -24,7 +27,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         group_size,
         infeatures,
         outfeatures,
-        bias
+        bias,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 4, 8]:
@@ -58,6 +62,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         else:
             self.bias = None

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):
@@ -122,7 +128,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):

     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
+        quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
+        out = quant_linear_fn.apply(
             x.reshape(-1, x.shape[-1]),
             self.qweight,
             self.scales,
@@ -160,11 +167,14 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
             m = 2 ** m
             for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
-                a = torch.randn(m, k, dtype=torch.float16, device=model.device)
-                quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                 if transpose:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                     a = torch.randn(m, n, dtype=torch.float16, device=model.device)
                     transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+                else:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
         del kn_values

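
In the triton QuantLinear, forward now dispatches on the flag: QuantLinearFunction (which defines a backward pass) when self.trainable is set, QuantLinearInferenceOnlyFunction otherwise. The dispatch pattern in isolation, with toy autograd functions standing in for the real quantized-matmul kernels:

import torch

class TrainableFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2  # gradient of x * 2

class InferenceOnlyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2  # no backward defined: lighter, but gradients cannot flow through it

def quant_forward(x, trainable):
    fn = TrainableFn if trainable else InferenceOnlyFn
    return fn.apply(x)

x = torch.randn(4, requires_grad=True)
quant_forward(x, trainable=True).sum().backward()
print(x.grad)  # tensor of 2s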

@@ -356,7 +356,6 @@ def silu(x):
     return x * tl.sigmoid(x)


-
 def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
     with torch.cuda.device(input.device):
         output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
@@ -414,3 +413,30 @@ class QuantLinearFunction(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
         return grad_input, None, None, None, None, None, None
+
+
+def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    with torch.cuda.device(input.device):
+        output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+        grid = lambda META: (
+            triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
+        )
+        quant_matmul_248_kernel[grid](
+            input, qweight, output,
+            scales, qzeros, g_idx,
+            input.shape[0], qweight.shape[1], input.shape[1],
+            bits, maxq,
+            input.stride(0), input.stride(1),
+            qweight.stride(0), qweight.stride(1),
+            output.stride(0), output.stride(1),
+            scales.stride(0), qzeros.stride(0)
+        )
+        return output
+
+
+class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+        output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+        return output
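
QuantLinearInferenceOnlyFunction implements only a forward pass, so autograd cannot propagate gradients through it; training has to go through QuantLinearFunction, which is why the trainable flag switches the dispatch shown earlier. A tiny illustration (toy code, unrelated to the real kernels) of why a forward-only autograd.Function cannot be trained through:

import torch

class ForwardOnly(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2  # no backward() defined

x = torch.randn(3, requires_grad=True)
y = ForwardOnly.apply(x).sum()
try:
    y.backward()
except Exception as err:  # backward-mode AD is unavailable without a backward()
    print(type(err).__name__, ":", err)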

@@ -7,6 +7,13 @@ try:
 except ImportError:
     TRITON_AVAILABLE = False

+try:
+    import autogptq_cuda
+
+    AUTOGPTQ_CUDA_AVAILABLE = True
+except:
+    AUTOGPTQ_CUDA_AVAILABLE = False
+

 def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
     if use_triton: