if exllama auto diable fused attention
This commit is contained in:
parent
11afc47f7f
commit
8c7c806d36
2 changed files with 22 additions and 12 deletions
|
@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
|
|||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
from logging import getLogger
|
||||
logger = getLogger(__name__)
|
||||
|
||||
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
||||
dim = x.shape[-1]
|
||||
|
@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
config = model.config
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
|
||||
return False
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, GPTJAttention):
|
||||
continue
|
||||
|
@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama":
|
||||
if desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
else:
|
||||
g_idx = None
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
|
||||
|
@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
setattr(parent, child_name, attn)
|
||||
del m
|
||||
|
||||
return True
|
||||
|
||||
__all__ = ["FusedGPTJAttentionForQuantizedModel"]
|
||||
|
|
|
@ -7,6 +7,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
|
|||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
from logging import getLogger
|
||||
logger = getLogger(__name__)
|
||||
|
||||
class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
@ -160,6 +163,12 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||
"""
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
||||
if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
|
||||
return False
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, LlamaAttention):
|
||||
|
@ -169,12 +178,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
k_proj = m.k_proj
|
||||
v_proj = m.v_proj
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
elif m.num_heads == m.num_key_value_heads:
|
||||
if m.num_heads == m.num_key_value_heads:
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
|
||||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
@ -240,3 +244,5 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
child_name = name
|
||||
|
||||
setattr(parent, child_name, attn)
|
||||
|
||||
return True
|
Loading…
Add table
Reference in a new issue