From f3a5a79b7bf6cf00f4fde89d893ad6ac4f45fefd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?=
Date: Wed, 30 Aug 2023 19:20:18 +0800
Subject: [PATCH] Fix g_idx in fused kernel

---
 auto_gptq/modeling/_utils.py             | 2 +-
 auto_gptq/nn_modules/fused_llama_attn.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py
index 6c2fd38..c005258 100644
--- a/auto_gptq/modeling/_utils.py
+++ b/auto_gptq/modeling/_utils.py
@@ -301,7 +301,7 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
             if max_input_length is None:
                 max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
             else:
-                max_input_len = max_input_len
+                max_input_len = max_input_length
         else:
             if max_input_length is not None:
                 logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")
diff --git a/auto_gptq/nn_modules/fused_llama_attn.py b/auto_gptq/nn_modules/fused_llama_attn.py
index 185770d..6dbd03a 100644
--- a/auto_gptq/nn_modules/fused_llama_attn.py
+++ b/auto_gptq/nn_modules/fused_llama_attn.py
@@ -164,7 +164,8 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 else:
                     g_idx = None
             else:
-                g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+                assert(torch.equal(q_proj.g_idx, k_proj.g_idx) and torch.equal(q_proj.g_idx, v_proj.g_idx))
+                g_idx = q_proj.g_idx

             bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
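
Note (illustrative sketch, not part of the patch): in GPTQ quantized linear layers, g_idx maps each input channel to its quantization group. Because q_proj, k_proj, and v_proj all consume the same hidden states, their g_idx tensors have the same length (in_features) and should be identical, so the fused QKV layer can reuse a single one; concatenating them along dim 0, as the old code did, produced a tensor three times too long for the fused layer's input dimension. Below is a minimal sketch of the corrected logic, using a hypothetical FakeProj stand-in rather than the real QuantLinear modules:

    import torch

    # Hypothetical stand-in for a quantized projection; in AutoGPTQ this would be
    # a QuantLinear module whose g_idx maps input channel -> quantization group.
    class FakeProj:
        def __init__(self, g_idx):
            self.g_idx = g_idx

    in_features, group_size = 4096, 128
    shared_g_idx = torch.arange(in_features, dtype=torch.int32) // group_size
    q_proj, k_proj, v_proj = (FakeProj(shared_g_idx.clone()) for _ in range(3))

    # Old behaviour: length 3 * in_features, mismatched with the fused layer's in_features.
    bad_g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
    assert bad_g_idx.shape[0] == 3 * in_features

    # Patched behaviour: check that the three projections agree, then reuse one g_idx.
    assert torch.equal(q_proj.g_idx, k_proj.g_idx) and torch.equal(q_proj.g_idx, v_proj.g_idx)
    g_idx = q_proj.g_idx
    assert g_idx.shape[0] == in_features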