add trainable mode

PanQiWei 2023-05-26 13:11:30 +08:00
parent fe5f5d12ed
commit 2b532f9453
12 changed files with 169 additions and 22 deletions

View file

@ -23,7 +23,7 @@ from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE
from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE
logger = getLogger(__name__)
@ -77,7 +77,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
fused_mlp_module_type: Optional[FusedBaseMLPModule] = None
def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
def __init__(
self,
model: PreTrainedModel,
quantized: bool,
quantize_config: BaseQuantizeConfig,
is_triton_backend: bool = False,
injected_fused_attention: bool = False,
injected_fused_mlp: bool = False,
trainable: bool = False
):
super().__init__()
self.model = model
@ -86,6 +95,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
self.quantize_config = quantize_config
self.config = self.model.config
self.is_triton_backend = is_triton_backend
self.injected_fused_attention = injected_fused_attention
self.injected_fused_mlp = injected_fused_mlp
self.trainable = trainable
@property
def quantized(self):
return self._quantized
@ -510,6 +524,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
use_safetensors: bool = False,
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
**kwargs
):
"""load quantized model from local disk"""
@ -517,7 +532,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
logger.warning("triton is not installed, reset use_triton to False")
use_triton = False
# == step1: prepare configs and file names == #
# == step1: prepare configs and file names, and check values of arguments passed in == #
config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
@ -543,6 +558,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
else:
raise FileNotFoundError(f"Could not find model at {model_save_name}")
if not use_triton and trainable:
raise NotImplementedError("For now, trainable mode only supports triton backend.")
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
pass
@ -578,7 +596,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
quantize_config.group_size,
use_triton=use_triton,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act
desc_act=quantize_config.desc_act,
trainable=trainable
)
model.tie_weights()
@ -643,6 +662,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
# == step5: (optional) inject optimized module == #
if inject_fused_attention:
if cls.fused_attn_module_type is None:
inject_fused_attention = False
logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
else:
cls.fused_attn_module_type.inject_to_model(
@ -650,10 +670,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
use_triton=use_triton,
group_size=quantize_config.group_size,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act
desc_act=quantize_config.desc_act,
trainable=trainable
)
if inject_fused_mlp:
if cls.fused_mlp_module_type is None:
inject_fused_mlp = False
logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
else:
cls.fused_mlp_module_type.inject_to_model(
@ -670,7 +692,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if inject_fused_mlp and cls.fused_mlp_module_type is not None:
cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
return cls(model, True, quantize_config)
return cls(
model,
True,
quantize_config,
is_triton_backend=use_triton,
injected_fused_attention=inject_fused_attention,
injected_fused_mlp=inject_fused_mlp and use_triton,
trainable=trainable
)
def warmup_triton(self, enabled: bool = True):
if not enabled:
@ -685,6 +715,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if self.fused_mlp_module_type is not None:
self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
def enable_trainable_mode(self, enabled: bool = True):
if not self.is_triton_backend and enabled:
raise NotImplementedError("For now, trainable mode only supports triton backend.")
for n, m in self.model.named_modules():
if hasattr(m, "trainable"):
setattr(m, "trainable", enabled)
def disable_trainable_mode(self):
self.enable_trainable_mode(enabled=False)
def __getattr__(self, item):
try:
return super().__getattr__(item)
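Taken together, the changes in this file mean from_quantized now accepts a trainable flag (triton backend only), records it on the wrapper alongside the backend and fused-module state, and enable_trainable_mode / disable_trainable_mode flip the trainable attribute on every quantized module at runtime. A minimal usage sketch, assuming a local quantized checkpoint at a placeholder path and the package-level AutoGPTQForCausalLM wrapper updated later in this diff:

from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",  # placeholder save_dir
    use_triton=True,            # trainable mode raises NotImplementedError without triton
    trainable=True,
)
model.enable_trainable_mode()   # sets .trainable = True on every module that exposes the flag
# ... run adapter-style fine-tuning here ...
model.disable_trainable_mode()  # route forward passes back through the inference-only kernels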

View file

@ -50,7 +50,17 @@ def get_module_by_name_suffix(model, module_name: str):
return module
def make_quant(module, names, bits, group_size, name='', use_triton=False, use_cuda_fp16=True, desc_act=False):
def make_quant(
module,
names,
bits,
group_size,
name='',
use_triton=False,
use_cuda_fp16=True,
desc_act=False,
trainable=False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
if isinstance(module, QuantLinear):
@ -71,13 +81,25 @@ def make_quant(module, names, bits, group_size, name='', use_triton=False, use_c
in_features = tmp.weight.shape[0]
out_features = tmp.weight.shape[1]
if (not(desc_act) or group_size == -1) and not use_triton:
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16)
new_layer = QuantLinear(
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
)
else:
new_layer = QuantLinear(bits, group_size, in_features, out_features, True)
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
new_layer.device = ori_layer_device
setattr(module, attr, new_layer.to(ori_layer_device))
for name1, child in module.named_children():
make_quant(child, names, bits, group_size, name + '.' + name1 if name != '' else name1, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16,desc_act=desc_act)
make_quant(
child,
names,
bits,
group_size,
name + '.' + name1 if name != '' else name1,
use_triton=use_triton,
use_cuda_fp16=use_cuda_fp16,
desc_act=desc_act,
trainable=trainable
)
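make_quant threads the new trainable flag through its recursion over named_children, rebuilding the dotted module name at each level so that only the listed layers get swapped for QuantLinear. A standalone sketch of that traversal pattern, with a caller-supplied factory standing in for the real QuantLinear constructor (not the library's exact code):

import torch.nn as nn

def replace_listed_linears(module, names, factory, prefix=""):
    # depth-first walk over named_children, accumulating the dotted name so matches are exact
    for child_name, child in module.named_children():
        full_name = f"{prefix}.{child_name}" if prefix else child_name
        if isinstance(child, nn.Linear) and full_name in names:
            setattr(module, child_name, factory(child.in_features, child.out_features))
        else:
            replace_listed_linears(child, names, factory, full_name)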
def pack_model(

View file

@ -65,6 +65,7 @@ class AutoGPTQForCausalLM:
use_safetensors: bool = False,
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(save_dir, trust_remote_code)
@ -85,6 +86,7 @@ class AutoGPTQForCausalLM:
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
trainable=trainable,
**keywords
)

View file

@ -18,7 +18,16 @@ class FusedBaseModule(nn.Module, TritonModuleMixin):
class FusedBaseAttentionModule(FusedBaseModule):
@classmethod
@abstractmethod
def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
**kwargs
):
raise NotImplementedError()
@classmethod

View file

@ -226,7 +226,16 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
return outputs # a, present, (attentions)
@classmethod
def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
**kwargs
):
config = model.config
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
@ -253,7 +262,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
True if q_proj.bias is not None else False,
)
qlinear_kwargs = dict()
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)

View file

@ -126,7 +126,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
return attn_output, attn_weights, past_key_value
@classmethod
def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
**kwargs
):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
@ -153,7 +162,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
True if q_proj.bias is not None else False,
)
qlinear_kwargs = dict()
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)

View file

@ -35,6 +35,8 @@ class GeneralQuantLinear(nn.Linear):
if hasattr(quant_linear_module, "autogptq_cuda_available"):
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
self.trainable = quant_linear_module.trainable
self.forward = quant_linear_module.forward
@classmethod

View file

@ -26,10 +26,13 @@ class QuantLinear(nn.Module):
outfeatures,
bias,
kernel_switch_threshold=128,
trainable=False
):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")
self.infeatures = infeatures
self.outfeatures = outfeatures
@ -76,6 +79,8 @@ class QuantLinear(nn.Module):
if infeatures % 256 != 0 or outfeatures % 256 != 0:
self.autogptq_cuda_available = False
self.trainable = trainable
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):

View file

@ -16,6 +16,7 @@ except ImportError:
logger.warning('CUDA extension not installed.')
_autogptq_cuda_available = False
class QuantLinear(nn.Module):
def __init__(
self,
@ -25,12 +26,15 @@ class QuantLinear(nn.Module):
outfeatures,
bias,
use_cuda_fp16=True,
kernel_switch_threshold=128
kernel_switch_threshold=128,
trainable=False
):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
@ -80,6 +84,8 @@ class QuantLinear(nn.Module):
if infeatures % 256 != 0 or outfeatures % 256 != 0:
self.autogptq_cuda_available = False
self.trainable = trainable
def pack(self, linear, scales, zeros, g_idx):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
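Both CUDA QuantLinear variants accept the new trainable argument only to fail fast at construction time. A quick illustration of what a caller hits, using the dispatcher extended at the end of this diff (the layer shape is a placeholder):

from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128)
try:
    QuantLinear(4, 128, 4096, 4096, True, trainable=True)  # bits, group_size, infeatures, outfeatures, bias
except NotImplementedError as err:
    print(err)  # the guard added above fires before any buffers are allocated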

View file

@ -11,7 +11,10 @@ from ..triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
try:
from ..triton_utils.kernels import quant_matmul_248, transpose_quant_matmul_248, QuantLinearFunction
from ..triton_utils.kernels import (
quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
QuantLinearFunction, QuantLinearInferenceOnlyFunction
)
except ImportError:
logger.error('triton not installed.')
raise
@ -24,7 +27,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
group_size,
infeatures,
outfeatures,
bias
bias,
trainable=False
):
super().__init__()
if bits not in [2, 4, 8]:
@ -58,6 +62,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
else:
self.bias = None
self.trainable = trainable
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
@ -122,7 +128,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
out = quant_linear_fn.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
@ -160,11 +167,14 @@ class QuantLinear(nn.Module, TritonModuleMixin):
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2 ** m
for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
if transpose:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
if transpose:
a = torch.randn(m, n, dtype=torch.float16, device=model.device)
transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
else:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values
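Only the triton QuantLinear honors the flag: forward dispatches to QuantLinearFunction (whose backward returns grad_input) when trainable is set and to QuantLinearInferenceOnlyFunction otherwise, so enable_trainable_mode can reroute a live model. A sketch of the behavioral difference on a single layer; the checkpoint path is a placeholder and the layer is simply the first quantized module found:

import torch
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model", device="cuda:0", use_triton=True, trainable=True
)
qlayer = next(m for m in model.model.modules() if hasattr(m, "qweight"))
x = torch.randn(1, qlayer.infeatures, dtype=torch.float16, device="cuda:0", requires_grad=True)

qlayer.trainable = True             # what enable_trainable_mode() sets on each such module
qlayer(x).float().sum().backward()  # succeeds: gradients reach the activation via QuantLinearFunction.backward
print(x.grad.shape)                 # the packed weights themselves stay frozen (backward only returns grad_input)

qlayer.trainable = False            # inference-only dispatch; a fresh forward and backward through this layer would now raise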

View file

@ -356,7 +356,6 @@ def silu(x):
return x * tl.sigmoid(x)
def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
@ -414,3 +413,30 @@ class QuantLinearFunction(torch.autograd.Function):
if ctx.needs_input_grad[0]:
grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
)
quant_matmul_248_kernel[grid](
input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1],
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
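The inference-only path mirrors quant_matmul_248 but pins the output dtype to float16, and the autograd wrapper around it deliberately defines no backward, so autograd refuses to differentiate through it. A toy, self-contained sketch of that forward-only pattern, with a plain matmul standing in for the triton kernel:

import torch
from torch.cuda.amp import custom_fwd

class MatmulInferenceOnly(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, a, b):
        # nothing is saved for backward and no backward() is defined, so any
        # attempt to backprop through this op raises NotImplementedError
        return a @ b

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(3, 4)
out = MatmulInferenceOnly.apply(a, b)
# out.sum().backward() would raise here, because backward()/vjp() is not implemented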

View file

@ -7,6 +7,13 @@ try:
except ImportError:
TRITON_AVAILABLE = False
try:
import autogptq_cuda
AUTOGPTQ_CUDA_AVAILABLE = True
except:
AUTOGPTQ_CUDA_AVAILABLE = False
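AUTOGPTQ_CUDA_AVAILABLE follows the same import-guard pattern as TRITON_AVAILABLE: try the import once and expose a module-level boolean. A brief, hypothetical sketch of how such flags can gate a backend request before loading (the RuntimeError is illustrative, not the library's behavior):

from auto_gptq.utils.import_utils import TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE

use_triton = True
if use_triton and not TRITON_AVAILABLE:
    use_triton = False  # same fallback _base.py performs, with a warning
if not use_triton and not AUTOGPTQ_CUDA_AVAILABLE:
    raise RuntimeError("neither triton nor the AutoGPTQ CUDA extension is installed")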
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
if use_triton: