Fix g_idx in fused kernel

楚天翔 2023-08-30 19:20:18 +08:00
parent 604c96144f
commit f3a5a79b7b
2 changed files with 3 additions and 2 deletions


@@ -301,7 +301,7 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
             if max_input_length is None:
                 max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
             else:
-                max_input_len = max_input_len
+                max_input_len = max_input_length
         else:
             if max_input_length is not None:
                 logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")


@@ -164,7 +164,8 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 else:
                     g_idx = None
             else:
-                g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+                assert(torch.equal(q_proj.g_idx, k_proj.g_idx) and torch.equal(q_proj.g_idx, v_proj.g_idx))
+                g_idx = q_proj.g_idx
             bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
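
Why the concatenation was wrong, sketched standalone (not the library code): `g_idx` maps each input channel of a quantized linear to its quantization group, and the fused QKV layer shares one input dimension across q, k and v, so it needs a single `g_idx` of length `in_features`. Concatenating the three along dim 0, the way the output-side tensors (`qweight`, `qzeros`, `scales`) are concatenated along dim 1, yields a tensor three times too long for the shared input dimension. The shapes and group layout below are illustrative assumptions.

import torch

in_features, group_size = 8, 4
# Hypothetical per-projection group index: one entry per *input* channel.
g_idx_q = torch.arange(in_features) // group_size   # tensor([0, 0, 0, 0, 1, 1, 1, 1])
g_idx_k = g_idx_q.clone()
g_idx_v = g_idx_q.clone()

# The fused layer keeps one shared input dimension, so the group index must
# be identical across the three projections and is simply reused.
assert torch.equal(g_idx_q, g_idx_k) and torch.equal(g_idx_q, g_idx_v)
fused_g_idx = g_idx_q                                   # shape (8,), matches in_features

wrong = torch.cat([g_idx_q, g_idx_k, g_idx_v], dim=0)   # shape (24,), no longer matches in_features
print(fused_g_idx.shape, wrong.shape)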