if exllama, auto-disable fused attention

qwopqwop200 2023-08-07 19:24:16 +09:00 committed by GitHub
parent 11afc47f7f
commit 8c7c806d36
2 changed files with 22 additions and 12 deletions


@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+logger = getLogger(__name__)
 
 
 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
@@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning("Exllama kernel does not support query/key/value fusion with act-order. Because of this, fused attention is automatically disabled.")
+            return False
 
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
@@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
 
@@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
             del m
+        return True
 
 
 __all__ = ["FusedGPTJAttentionForQuantizedModel"]


@@ -7,6 +7,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+
+logger = getLogger(__name__)
 
 
 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -160,6 +163,12 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue may lie in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning("Exllama kernel does not support query/key/value fusion with act-order. Because of this, fused attention is automatically disabled.")
+            return False
 
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
@@ -169,12 +178,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             k_proj = m.k_proj
             v_proj = m.v_proj
 
-            if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
-                # TODO: support it. The issue lies maybe in the line:
-                # int groups = qzeros.size(0);
-                # in exllama_ext.cpp
-                raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-            elif m.num_heads == m.num_key_value_heads:
+            if m.num_heads == m.num_key_value_heads:
                 qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
                 qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
                 scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
@@ -240,3 +244,5 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 child_name = name
             setattr(parent, child_name, attn)
+
+        return True
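
Because injection can now be skipped with only a warning, it can be useful to check at runtime whether fusion actually happened. A minimal sketch under the assumption that the loaded model is (or wraps) a Hugging Face Llama model; the helper name is ours, not part of AutoGPTQ:

# Hedged sketch, not from this commit: detects leftover plain LlamaAttention
# modules, which indicates that fused attention was not injected.
from transformers.models.llama.modeling_llama import LlamaAttention


def uses_fused_attention(model) -> bool:
    """Return True if no plain LlamaAttention modules remain, i.e. fusion ran."""
    return not any(isinstance(m, LlamaAttention) for m in model.modules())

With an act-order checkpoint and the exllama kernel enabled, this returns False after loading, matching the warning emitted above.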