diff --git a/auto_gptq/nn_modules/fused_gptj_attn.py b/auto_gptq/nn_modules/fused_gptj_attn.py
index 760b61f..bc39347 100644
--- a/auto_gptq/nn_modules/fused_gptj_attn.py
+++ b/auto_gptq/nn_modules/fused_gptj_attn.py
@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
 
 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
@@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
 
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See the comment in fused_llama_attn.py.
+            logger.warning("Exllama kernel does not support query/key/value fusion with act-order. Because of this, fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
@@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
 
@@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
 
             del m
+        return True
 
 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
diff --git a/auto_gptq/nn_modules/fused_llama_attn.py b/auto_gptq/nn_modules/fused_llama_attn.py
index d0e01cc..5469ef1 100644
--- a/auto_gptq/nn_modules/fused_llama_attn.py
+++ b/auto_gptq/nn_modules/fused_llama_attn.py
@@ -7,6 +7,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
+
 
 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -160,7 +163,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
-
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support this case. The issue may lie in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning("Exllama kernel does not support query/key/value fusion with act-order. Because of this, fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
                 continue
@@ -169,12 +178,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             k_proj = m.k_proj
             v_proj = m.v_proj
 
-            if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
-                # TODO: support it. The issue lies maybe in the line:
-                # int groups = qzeros.size(0);
-                # in exllama_ext.cpp
-                raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-            elif m.num_heads == m.num_key_value_heads:
+            if m.num_heads == m.num_key_value_heads:
                 qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
                 qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
                 scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
@@ -240,3 +244,5 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 child_name = name
 
             setattr(parent, child_name, attn)
+
+        return True
\ No newline at end of file
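
Note on the behavioral change: in both files, when the exllama kernel is selected and the checkpoint was quantized with act-order (desc_act=True), the injection classmethod now logs a warning and returns False instead of raising a ValueError, and it returns True once the fused attention modules have been injected. The sketch below shows how calling code might consume that boolean, assuming the classmethod shown in the hunks is inject_to_model as in AutoGPTQ's FusedBaseAttentionModule; the helper name try_inject_fused_attention and the pass-through keyword arguments are illustrative assumptions, not part of this patch.

from logging import getLogger

logger = getLogger(__name__)

def try_inject_fused_attention(fused_attn_cls, model, **inject_kwargs):
    # Hypothetical helper: fused_attn_cls is one of the classes touched by this patch,
    # e.g. FusedLlamaAttentionForQuantizedModel or FusedGPTJAttentionForQuantizedModel;
    # inject_kwargs would carry use_triton, group_size, desc_act, bits, disable_exllama, ...
    injected = fused_attn_cls.inject_to_model(model, **inject_kwargs)
    if injected is False:
        # Unsupported combination (exllama kernel + act-order): the model keeps its
        # regular, unfused attention modules and inference still works.
        logger.info("Fused attention was not injected; keeping the standard attention modules.")
    return injected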