add trainable mode
parent fe5f5d12ed
commit 2b532f9453
12 changed files with 169 additions and 22 deletions
@@ -23,7 +23,7 @@ from ..nn_modules.qlinear import GeneralQuantLinear
 from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
-from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE
+from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE

 logger = getLogger(__name__)

@@ -77,7 +77,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
     fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

-    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        quantized: bool,
+        quantize_config: BaseQuantizeConfig,
+        is_triton_backend: bool = False,
+        injected_fused_attention: bool = False,
+        injected_fused_mlp: bool = False,
+        trainable: bool = False
+    ):
         super().__init__()

         self.model = model
@@ -86,6 +95,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         self.quantize_config = quantize_config
         self.config = self.model.config

+        self.is_triton_backend = is_triton_backend
+        self.injected_fused_attention = injected_fused_attention
+        self.injected_fused_mlp = injected_fused_mlp
+        self.trainable = trainable
+
     @property
     def quantized(self):
         return self._quantized
@@ -510,6 +524,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ):
         """load quantized model from local disk"""
@@ -517,7 +532,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             logger.warning("triton is not installed, reset use_triton to False")
             use_triton = False

-        # == step1: prepare configs and file names == #
+        # == step1: prepare configs and file names, and check values of arguments passed in == #
         config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")
@@ -543,6 +558,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         else:
             raise FileNotFoundError(f"Could not find model at {model_save_name}")

+        if not use_triton and trainable:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+
         # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
         def skip(*args, **kwargs):
             pass
@@ -578,7 +596,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             quantize_config.group_size,
             use_triton=use_triton,
             use_cuda_fp16=use_cuda_fp16,
-            desc_act=quantize_config.desc_act
+            desc_act=quantize_config.desc_act,
+            trainable=trainable
         )
         model.tie_weights()

@@ -643,6 +662,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         # == step5: (optional) inject optimized module == #
         if inject_fused_attention:
             if cls.fused_attn_module_type is None:
+                inject_fused_attention = False
                 logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
             else:
                 cls.fused_attn_module_type.inject_to_model(
@@ -650,10 +670,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                     use_triton=use_triton,
                     group_size=quantize_config.group_size,
                     use_cuda_fp16=use_cuda_fp16,
-                    desc_act=quantize_config.desc_act
+                    desc_act=quantize_config.desc_act,
+                    trainable=trainable
                 )
         if inject_fused_mlp:
             if cls.fused_mlp_module_type is None:
+                inject_fused_mlp = False
                 logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
             else:
                 cls.fused_mlp_module_type.inject_to_model(
@@ -670,7 +692,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if inject_fused_mlp and cls.fused_mlp_module_type is not None:
             cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)

-        return cls(model, True, quantize_config)
+        return cls(
+            model,
+            True,
+            quantize_config,
+            is_triton_backend=use_triton,
+            injected_fused_attention=inject_fused_attention,
+            injected_fused_mlp=inject_fused_mlp and use_triton,
+            trainable=trainable
+        )

     def warmup_triton(self, enabled: bool = True):
         if not enabled:
@@ -685,6 +715,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if self.fused_mlp_module_type is not None:
             self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)

+    def enable_trainable_mode(self, enabled: bool = True):
+        if not self.is_triton_backend and enabled:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+        for n, m in self.model.named_modules():
+            if hasattr(m, "trainable"):
+                setattr(m, "trainable", enabled)
+
+    def disable_trainable_mode(self):
+        self.enable_trainable_mode(enabled=False)
+
     def __getattr__(self, item):
         try:
             return super().__getattr__(item)

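A minimal usage sketch of the trainable flow above, assuming a quantized checkpoint already sits at the hypothetical path ./opt-125m-4bit and that triton is installed (the only backend trainable mode accepts for now):

from auto_gptq import AutoGPTQForCausalLM

# Load with trainable=True so every injected QuantLinear is built with the flag set;
# use_triton=True is required, otherwise from_quantized raises NotImplementedError.
model = AutoGPTQForCausalLM.from_quantized(
    "./opt-125m-4bit",      # placeholder path to an already-quantized model
    use_safetensors=True,
    use_triton=True,
    trainable=True,
)

# The mode can also be toggled after loading: enable_trainable_mode walks
# named_modules() and flips the `trainable` attribute on every quantized layer.
model.enable_trainable_mode()
# ... fine-tuning / gradient work goes here ...
model.disable_trainable_mode()  # back to the inference-only triton path
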
@@ -50,7 +50,17 @@ def get_module_by_name_suffix(model, module_name: str):
     return module


-def make_quant(module, names, bits, group_size, name='', use_triton=False, use_cuda_fp16=True, desc_act=False):
+def make_quant(
+    module,
+    names,
+    bits,
+    group_size,
+    name='',
+    use_triton=False,
+    use_cuda_fp16=True,
+    desc_act=False,
+    trainable=False
+):
     QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

     if isinstance(module, QuantLinear):
@@ -71,13 +81,25 @@ def make_quant(module, names, bits, group_size, name='', use_triton=False, use_c
                 in_features = tmp.weight.shape[0]
                 out_features = tmp.weight.shape[1]
             if (not(desc_act) or group_size == -1) and not use_triton:
-                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16)
+                new_layer = QuantLinear(
+                    bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
+                )
             else:
-                new_layer = QuantLinear(bits, group_size, in_features, out_features, True)
+                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
             new_layer.device = ori_layer_device
             setattr(module, attr, new_layer.to(ori_layer_device))
     for name1, child in module.named_children():
-        make_quant(child, names, bits, group_size, name + '.' + name1 if name != '' else name1, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16,desc_act=desc_act)
+        make_quant(
+            child,
+            names,
+            bits,
+            group_size,
+            name + '.' + name1 if name != '' else name1,
+            use_triton=use_triton,
+            use_cuda_fp16=use_cuda_fp16,
+            desc_act=desc_act,
+            trainable=trainable
+        )


 def pack_model(

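make_quant now threads trainable through every recursive call, so after a model is loaded the flag should be present on each replaced layer. A small sanity check (sketch; model is assumed to be a GPTQ model wrapper loaded with trainable=True as in the earlier example):

# Collect every sub-module that carries the flag (the injected QuantLinear layers).
flagged = [
    (name, m.trainable)
    for name, m in model.model.named_modules()
    if hasattr(m, "trainable")
]
print(f"{sum(t for _, t in flagged)}/{len(flagged)} quantized layers have trainable=True")
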
@@ -65,6 +65,7 @@ class AutoGPTQForCausalLM:
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(save_dir, trust_remote_code)
@@ -85,6 +86,7 @@ class AutoGPTQForCausalLM:
             use_safetensors=use_safetensors,
             trust_remote_code=trust_remote_code,
             warmup_triton=warmup_triton,
+            trainable=trainable,
             **keywords
         )

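AutoGPTQForCausalLM.from_quantized just forwards the new keyword to the model-specific loader, so the triton-only restriction from the base class is what a caller actually hits. A sketch of the failure mode, continuing the earlier example's import and placeholder checkpoint path:

try:
    AutoGPTQForCausalLM.from_quantized("./opt-125m-4bit", use_triton=False, trainable=True)
except NotImplementedError as e:
    print(e)  # "For now, trainable mode only supports triton backend."
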
@@ -18,7 +18,16 @@ class FusedBaseModule(nn.Module, TritonModuleMixin):
 class FusedBaseAttentionModule(FusedBaseModule):
     @classmethod
     @abstractmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         raise NotImplementedError()

     @classmethod

@@ -226,7 +226,16 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         return outputs  # a, present, (attentions)

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

@@ -253,7 +262,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
                 q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
                 True if q_proj.bias is not None else False,
             )
-            qlinear_kwargs = dict()
+            qlinear_kwargs = {"trainable": trainable}
             if (not desc_act or group_size == -1) and not use_triton:
                 qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
             qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)

@@ -126,7 +126,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         return attn_output, attn_weights, past_key_value

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         """
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
@@ -153,7 +162,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
                 True if q_proj.bias is not None else False,
             )
-            qlinear_kwargs = dict()
+            qlinear_kwargs = {"trainable": trainable}
             if (not desc_act or group_size == -1) and not use_triton:
                 qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
             qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)

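Both fused attention injectors now build their QuantLinear kwargs the same way; isolated below as a sketch (the helper name is invented for illustration): trainable is always passed, while use_cuda_fp16 only applies to the old cuda kernel path.

def build_qlinear_kwargs(desc_act, group_size, use_triton, use_cuda_fp16, trainable):
    kwargs = {"trainable": trainable}  # accepted by every QuantLinear variant after this commit
    # The old cuda kernel is chosen only when act-order is off (or grouping disabled)
    # and triton is not used; it is the only implementation that takes use_cuda_fp16.
    if (not desc_act or group_size == -1) and not use_triton:
        kwargs["use_cuda_fp16"] = use_cuda_fp16
    return kwargs
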
@@ -35,6 +35,8 @@ class GeneralQuantLinear(nn.Linear):
         if hasattr(quant_linear_module, "autogptq_cuda_available"):
             self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available

+        self.trainable = quant_linear_module.trainable
+
         self.forward = quant_linear_module.forward

     @classmethod

@@ -26,10 +26,13 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         kernel_switch_threshold=128,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")

         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -76,6 +79,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):

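The cuda QuantLinear implementations fail fast instead of silently producing non-differentiable outputs. A sketch of what a caller sees (layer sizes are hypothetical; the import path is the package's real utils module):

from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128)
try:
    QuantLinear(4, 128, 4096, 4096, True, trainable=True)  # bits, group_size, in, out, bias
except NotImplementedError as e:
    print(e)  # "QuantLinear with cuda backend not support trainable mode yet."
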
@@ -16,6 +16,7 @@ except ImportError:
     logger.warning('CUDA extension not installed.')
     _autogptq_cuda_available = False

+
 class QuantLinear(nn.Module):
     def __init__(
         self,
@@ -25,12 +26,15 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         use_cuda_fp16=True,
-        kernel_switch_threshold=128
+        kernel_switch_threshold=128,
+        trainable=False
     ):

         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
@@ -80,6 +84,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx):
         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()

@@ -11,7 +11,10 @@ from ..triton_utils.mixin import TritonModuleMixin
 logger = getLogger(__name__)

 try:
-    from ..triton_utils.kernels import quant_matmul_248, transpose_quant_matmul_248, QuantLinearFunction
+    from ..triton_utils.kernels import (
+        quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
+        QuantLinearFunction, QuantLinearInferenceOnlyFunction
+    )
 except ImportError:
     logger.error('triton not installed.')
     raise
@@ -24,7 +27,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         group_size,
         infeatures,
         outfeatures,
-        bias
+        bias,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 4, 8]:
@@ -58,6 +62,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         else:
             self.bias = None

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):
@@ -122,7 +128,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):

     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
+        quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
+        out = quant_linear_fn.apply(
             x.reshape(-1, x.shape[-1]),
             self.qweight,
             self.scales,
@@ -160,11 +167,14 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
             m = 2 ** m
             for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
-                a = torch.randn(m, k, dtype=torch.float16, device=model.device)
-                quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                 if transpose:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                     a = torch.randn(m, n, dtype=torch.float16, device=model.device)
                     transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+                else:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
         del kn_values

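What the quant_linear_fn switch above buys: with trainable=True the triton layer goes through QuantLinearFunction, whose backward (transpose_quant_matmul_248) lets gradients flow back to the activations. A rough sketch, assuming qlayer is a packed triton QuantLinear pulled from a loaded model on a CUDA device:

import torch

qlayer.trainable = True   # what model.enable_trainable_mode() sets on each layer
x = torch.randn(8, qlayer.infeatures, dtype=torch.float16, device="cuda", requires_grad=True)
y = qlayer(x)             # forward dispatches to QuantLinearFunction
y.float().sum().backward()  # backward runs the transposed quant matmul kernel
print(x.grad.shape)       # gradients w.r.t. the activations are now available

qlayer.trainable = False  # switch back to the inference-only function (no backward defined)
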
@@ -356,7 +356,6 @@ def silu(x):
     return x * tl.sigmoid(x)


-
 def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
     with torch.cuda.device(input.device):
         output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
@@ -414,3 +413,30 @@ class QuantLinearFunction(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
         return grad_input, None, None, None, None, None, None
+
+
+def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    with torch.cuda.device(input.device):
+        output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+        grid = lambda META: (
+            triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
+        )
+        quant_matmul_248_kernel[grid](
+            input, qweight, output,
+            scales, qzeros, g_idx,
+            input.shape[0], qweight.shape[1], input.shape[1],
+            bits, maxq,
+            input.stride(0), input.stride(1),
+            qweight.stride(0), qweight.stride(1),
+            output.stride(0), output.stride(1),
+            scales.stride(0), qzeros.stride(0)
+        )
+        return output
+
+
+class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+        output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+        return output
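Design note on the new pieces: quant_matmul_inference_only_248 always allocates a float16 output, and QuantLinearInferenceOnlyFunction defines no backward, so autograd will not propagate gradients through it. Calling it directly mirrors the trainable variant (sketch; q is assumed to be a packed triton QuantLinear and x2d a 2-D fp16 activation):

out = QuantLinearInferenceOnlyFunction.apply(
    x2d, q.qweight, q.scales, q.qzeros, q.g_idx, q.bits, q.maxq
)
# No backward is implemented here, so gradients cannot flow through this call;
# QuantLinearFunction.apply(...) is the path used when the layer is trainable.
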
@@ -7,6 +7,13 @@ try:
 except ImportError:
     TRITON_AVAILABLE = False

+try:
+    import autogptq_cuda
+
+    AUTOGPTQ_CUDA_AVAILABLE = True
+except:
+    AUTOGPTQ_CUDA_AVAILABLE = False
+

 def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
     if use_triton:
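AUTOGPTQ_CUDA_AVAILABLE mirrors the existing TRITON_AVAILABLE probe and is what the BaseGPTQForCausalLM module now imports. A sketch of using the two flags to decide whether a trainable load is even possible:

from auto_gptq.utils.import_utils import TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE

can_train = TRITON_AVAILABLE  # trainable mode is triton-only for now
can_infer = TRITON_AVAILABLE or AUTOGPTQ_CUDA_AVAILABLE
print(f"inference possible: {can_infer}, trainable mode possible: {can_train}")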