add trainable mode

Author: PanQiWei, 2023-05-26 13:11:30 +08:00
parent fe5f5d12ed
commit 2b532f9453
12 changed files with 169 additions and 22 deletions

View file

@@ -23,7 +23,7 @@ from ..nn_modules.qlinear import GeneralQuantLinear
 from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
-from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE
+from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE

 logger = getLogger(__name__)

@@ -77,7 +77,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
     fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

-    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        quantized: bool,
+        quantize_config: BaseQuantizeConfig,
+        is_triton_backend: bool = False,
+        injected_fused_attention: bool = False,
+        injected_fused_mlp: bool = False,
+        trainable: bool = False
+    ):
         super().__init__()

         self.model = model

@@ -86,6 +95,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         self.quantize_config = quantize_config
         self.config = self.model.config

+        self.is_triton_backend = is_triton_backend
+        self.injected_fused_attention = injected_fused_attention
+        self.injected_fused_mlp = injected_fused_mlp
+        self.trainable = trainable
+
     @property
     def quantized(self):
         return self._quantized

@@ -510,6 +524,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ):
         """load quantized model from local disk"""

@@ -517,7 +532,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             logger.warning("triton is not installed, reset use_triton to False")
             use_triton = False

-        # == step1: prepare configs and file names == #
+        # == step1: prepare configs and file names, and check values of arguments passed in == #
         config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")

@@ -543,6 +558,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         else:
             raise FileNotFoundError(f"Could not find model at {model_save_name}")

+        if not use_triton and trainable:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+
         # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
         def skip(*args, **kwargs):
             pass

@@ -578,7 +596,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             quantize_config.group_size,
             use_triton=use_triton,
             use_cuda_fp16=use_cuda_fp16,
-            desc_act=quantize_config.desc_act
+            desc_act=quantize_config.desc_act,
+            trainable=trainable
         )
         model.tie_weights()

@@ -643,6 +662,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         # == step5: (optional) inject optimized module == #
         if inject_fused_attention:
             if cls.fused_attn_module_type is None:
+                inject_fused_attention = False
                 logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
             else:
                 cls.fused_attn_module_type.inject_to_model(

@@ -650,10 +670,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                     use_triton=use_triton,
                     group_size=quantize_config.group_size,
                     use_cuda_fp16=use_cuda_fp16,
-                    desc_act=quantize_config.desc_act
+                    desc_act=quantize_config.desc_act,
+                    trainable=trainable
                 )
         if inject_fused_mlp:
             if cls.fused_mlp_module_type is None:
+                inject_fused_mlp = False
                 logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
             else:
                 cls.fused_mlp_module_type.inject_to_model(

@@ -670,7 +692,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             if inject_fused_mlp and cls.fused_mlp_module_type is not None:
                 cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)

-        return cls(model, True, quantize_config)
+        return cls(
+            model,
+            True,
+            quantize_config,
+            is_triton_backend=use_triton,
+            injected_fused_attention=inject_fused_attention,
+            injected_fused_mlp=inject_fused_mlp and use_triton,
+            trainable=trainable
+        )

     def warmup_triton(self, enabled: bool = True):
         if not enabled:

@@ -685,6 +715,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if self.fused_mlp_module_type is not None:
             self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)

+    def enable_trainable_mode(self, enabled: bool = True):
+        if not self.is_triton_backend and enabled:
+            raise NotImplementedError("For now, trainable mode only supports triton backend.")
+        for n, m in self.model.named_modules():
+            if hasattr(m, "trainable"):
+                setattr(m, "trainable", enabled)
+
+    def disable_trainable_mode(self):
+        self.enable_trainable_mode(enabled=False)
+
     def __getattr__(self, item):
         try:
             return super().__getattr__(item)
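Taken together, the changes to this file let a loaded model be switched between inference and trainable mode at runtime via the two new methods. A minimal usage sketch, assuming `model` is an instance loaded with use_triton=True so that `is_triton_backend` is set; everything else follows the methods added above:

    model.enable_trainable_mode()    # sets m.trainable = True on every submodule exposing the flag
    # ... run fine-tuning here; gradients reach the activations through QuantLinearFunction ...
    model.disable_trainable_mode()   # shorthand for model.enable_trainable_mode(enabled=False)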

View file

@@ -50,7 +50,17 @@ def get_module_by_name_suffix(model, module_name: str):
             return module


-def make_quant(module, names, bits, group_size, name='', use_triton=False, use_cuda_fp16=True, desc_act=False):
+def make_quant(
+    module,
+    names,
+    bits,
+    group_size,
+    name='',
+    use_triton=False,
+    use_cuda_fp16=True,
+    desc_act=False,
+    trainable=False
+):
     QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

     if isinstance(module, QuantLinear):

@@ -71,13 +81,25 @@ def make_quant(module, names, bits, group_size, name='', use_triton=False, use_c
             in_features = tmp.weight.shape[0]
             out_features = tmp.weight.shape[1]
             if (not(desc_act) or group_size == -1) and not use_triton:
-                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16)
+                new_layer = QuantLinear(
+                    bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
+                )
             else:
-                new_layer = QuantLinear(bits, group_size, in_features, out_features, True)
+                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
             new_layer.device = ori_layer_device
             setattr(module, attr, new_layer.to(ori_layer_device))
     for name1, child in module.named_children():
-        make_quant(child, names, bits, group_size, name + '.' + name1 if name != '' else name1, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act)
+        make_quant(
+            child,
+            names,
+            bits,
+            group_size,
+            name + '.' + name1 if name != '' else name1,
+            use_triton=use_triton,
+            use_cuda_fp16=use_cuda_fp16,
+            desc_act=desc_act,
+            trainable=trainable
+        )


 def pack_model(
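A hedged sketch of how the widened make_quant signature might be called (the `model` object and the layer-name suffixes are hypothetical; only the keyword names come from the signature above). The point is that `trainable` is threaded through the recursion so every replacement QuantLinear is constructed with the same flag:

    make_quant(
        model,                                     # hypothetical root nn.Module, rewritten in place
        names=["q_proj", "k_proj", "v_proj"],      # hypothetical suffixes of Linear layers to replace
        bits=4,
        group_size=128,
        use_triton=True,                           # trainable mode is only implemented for the triton QuantLinear
        desc_act=False,
        trainable=True                             # forwarded to every QuantLinear(...) constructed below
    )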

View file

@@ -65,6 +65,7 @@ class AutoGPTQForCausalLM:
         use_safetensors: bool = False,
         trust_remote_code: bool = False,
         warmup_triton: bool = False,
+        trainable: bool = False,
         **kwargs
     ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(save_dir, trust_remote_code)

@@ -85,6 +86,7 @@ class AutoGPTQForCausalLM:
             use_safetensors=use_safetensors,
             trust_remote_code=trust_remote_code,
             warmup_triton=warmup_triton,
+            trainable=trainable,
             **keywords
         )
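For end users the flag surfaces on the auto class. A hedged sketch (the classmethod in this hunk is presumably `from_quantized`; the directory path is a placeholder, and `use_triton` is assumed to be accepted through the remaining keyword arguments, since trainable mode requires the triton backend as checked in the base loader above):

    model = AutoGPTQForCausalLM.from_quantized(
        "path/to/quantized-model-dir",   # placeholder save_dir
        use_safetensors=True,
        use_triton=True,                 # required: trainable mode raises NotImplementedError otherwise
        trainable=True                   # forwarded unchanged to the model-specific loader
    )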

View file

@@ -18,7 +18,16 @@ class FusedBaseModule(nn.Module, TritonModuleMixin):
 class FusedBaseAttentionModule(FusedBaseModule):
     @classmethod
     @abstractmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         raise NotImplementedError()

     @classmethod

View file

@@ -226,7 +226,16 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         return outputs  # a, present, (attentions)

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)

@@ -253,7 +262,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
             True if q_proj.bias is not None else False,
         )
-        qlinear_kwargs = dict()
+        qlinear_kwargs = {"trainable": trainable}
         if (not desc_act or group_size == -1) and not use_triton:
             qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
         qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)

View file

@@ -126,7 +126,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         return attn_output, attn_weights, past_key_value

     @classmethod
-    def inject_to_model(cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, **kwargs):
+    def inject_to_model(
+        cls,
+        model,
+        use_triton=False,
+        group_size=-1,
+        use_cuda_fp16=True,
+        desc_act=False,
+        trainable=False,
+        **kwargs
+    ):
         """
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """

@@ -153,7 +162,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
             True if q_proj.bias is not None else False,
         )
-        qlinear_kwargs = dict()
+        qlinear_kwargs = {"trainable": trainable}
         if (not desc_act or group_size == -1) and not use_triton:
             qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
         qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)

View file

@@ -35,6 +35,8 @@ class GeneralQuantLinear(nn.Linear):
         if hasattr(quant_linear_module, "autogptq_cuda_available"):
             self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available

+        self.trainable = quant_linear_module.trainable
+
         self.forward = quant_linear_module.forward

     @classmethod

View file

@@ -26,10 +26,13 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         kernel_switch_threshold=128,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")

         self.infeatures = infeatures
         self.outfeatures = outfeatures

@@ -76,6 +79,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):

View file

@@ -16,6 +16,7 @@
 except ImportError:
     logger.warning('CUDA extension not installed.')
     _autogptq_cuda_available = False
+
 class QuantLinear(nn.Module):
     def __init__(
         self,

@@ -25,12 +26,15 @@ class QuantLinear(nn.Module):
         outfeatures,
         bias,
         use_cuda_fp16=True,
-        kernel_switch_threshold=128
+        kernel_switch_threshold=128,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        if trainable:
+            raise NotImplementedError("QuantLinear with cuda backend not support trainable mode yet.")

         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits

@@ -80,6 +84,8 @@ class QuantLinear(nn.Module):
         if infeatures % 256 != 0 or outfeatures % 256 != 0:
             self.autogptq_cuda_available = False

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx):
         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()

View file

@@ -11,7 +11,10 @@ from ..triton_utils.mixin import TritonModuleMixin
 logger = getLogger(__name__)

 try:
-    from ..triton_utils.kernels import quant_matmul_248, transpose_quant_matmul_248, QuantLinearFunction
+    from ..triton_utils.kernels import (
+        quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
+        QuantLinearFunction, QuantLinearInferenceOnlyFunction
+    )
 except ImportError:
     logger.error('triton not installed.')
     raise

@@ -24,7 +27,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         group_size,
         infeatures,
         outfeatures,
-        bias
+        bias,
+        trainable=False
     ):
         super().__init__()
         if bits not in [2, 4, 8]:

@@ -58,6 +62,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         else:
             self.bias = None

+        self.trainable = trainable
+
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):

@@ -122,7 +128,8 @@ class QuantLinear(nn.Module, TritonModuleMixin):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
+        quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
+        out = quant_linear_fn.apply(
             x.reshape(-1, x.shape[-1]),
             self.qweight,
             self.scales,

@@ -160,11 +167,14 @@ class QuantLinear(nn.Module, TritonModuleMixin):
         for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
             m = 2 ** m
             for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
-                a = torch.randn(m, k, dtype=torch.float16, device=model.device)
-                quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                 if transpose:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                     a = torch.randn(m, n, dtype=torch.float16, device=model.device)
                     transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+                else:
+                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+                    quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)

         del kn_values
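The forward dispatch above is the heart of trainable mode: when `self.trainable` is set, the layer routes through the autograd-aware QuantLinearFunction, otherwise through the new inference-only path. A small sketch of what that enables, assuming `layer` is an already-packed triton QuantLinear on a CUDA device (the shapes are illustrative):

    import torch

    layer.trainable = True                     # normally flipped by enable_trainable_mode()
    x = torch.randn(4, layer.infeatures, dtype=torch.float16, device="cuda", requires_grad=True)
    y = layer(x)                               # QuantLinearFunction.apply -> quant_matmul_248
    y.sum().backward()                         # backward runs transpose_quant_matmul_248
    print(x.grad.shape)                        # gradients flow to the activations; packed weights stay frozen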

View file

@@ -356,7 +356,6 @@ def silu(x):
     return x * tl.sigmoid(x)

-
 def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
     with torch.cuda.device(input.device):
         output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)

@@ -414,3 +413,30 @@ class QuantLinearFunction(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
         return grad_input, None, None, None, None, None, None
+
+
+def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    with torch.cuda.device(input.device):
+        output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+        grid = lambda META: (
+            triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
+        )
+        quant_matmul_248_kernel[grid](
+            input, qweight, output,
+            scales, qzeros, g_idx,
+            input.shape[0], qweight.shape[1], input.shape[1],
+            bits, maxq,
+            input.stride(0), input.stride(1),
+            qweight.stride(0), qweight.stride(1),
+            output.stride(0), output.stride(1),
+            scales.stride(0), qzeros.stride(0)
+        )
+        return output
+
+
+class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+        output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+        return output
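QuantLinearInferenceOnlyFunction deliberately defines no backward, so autograd will raise if a gradient ever reaches it; the trainable path keeps using QuantLinearFunction, whose backward (context above) computes the input gradient with transpose_quant_matmul_248. A hedged sketch contrasting the two, assuming qweight, scales, qzeros, g_idx, bits and maxq come from an already-packed layer on a CUDA device and in_features is illustrative:

    import torch

    x = torch.randn(8, in_features, dtype=torch.float16, device="cuda", requires_grad=True)

    # trainable path: differentiable with respect to the activations
    out = QuantLinearFunction.apply(x, qweight, scales, qzeros, g_idx, bits, maxq)
    out.sum().backward()        # ok: grad_input via transpose_quant_matmul_248

    # inference-only path: forward works, but running backward through it raises
    out = QuantLinearInferenceOnlyFunction.apply(x, qweight, scales, qzeros, g_idx, bits, maxq)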

View file

@@ -7,6 +7,13 @@ try:
 except ImportError:
     TRITON_AVAILABLE = False

+try:
+    import autogptq_cuda
+    AUTOGPTQ_CUDA_AVAILABLE = True
+except:
+    AUTOGPTQ_CUDA_AVAILABLE = False
+
+
 def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
     if use_triton: