if exllama, auto disable fused attention

qwopqwop200 2023-09-06 18:14:04 +09:00 committed by GitHub
parent ad5b0d72ee
commit 6b1ceb1897
2 changed files with 20 additions and 15 deletions
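In effect, the change turns a hard error into a soft fallback: inject_to_model now logs a warning and returns False when the exllama kernel is combined with act-order (desc_act) quantization, and returns True once fusion has been injected. A minimal sketch of how calling code might consume that return value (the wrapper below is illustrative, not part of this commit):

from logging import getLogger

logger = getLogger(__name__)

def try_inject_fused_attention(model, fused_attn_cls, **inject_kwargs):
    # Illustrative wrapper; fused_attn_cls would be one of the classes touched
    # in this commit, e.g. FusedLlamaAttentionForQuantizedModel or the GPTJ one.
    injected = fused_attn_cls.inject_to_model(model, **inject_kwargs)
    if not injected:
        # exllama + desc_act: inject_to_model has already logged its warning
        # and left the original attention modules untouched.
        logger.info("Continuing without fused attention.")
    return injected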

fused_gptj_attn.py

@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+logger = getLogger(__name__)


 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
@@ -240,7 +242,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
     ):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False

         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
@@ -256,11 +262,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
@@ -297,6 +299,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
             del m

         return True

 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
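Both files now carry the same early-return guard ahead of the q/k/v concatenation. Condensed into one helper purely for illustration (the function name is hypothetical; the condition and message come from the added lines):

from logging import getLogger

logger = getLogger(__name__)

def fusion_supported(QuantLinear, desc_act: bool) -> bool:
    # Hypothetical helper mirroring the guard added to both inject_to_model
    # methods: exllama cannot fuse the q/k/v projections for act-order weights,
    # so fusion is skipped with a warning instead of raising later.
    if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
        logger.warning(
            "Exllama kernel does not support query/key/value fusion with act-order. "
            "Because of this, Fused attention is automatically disabled."
        )
        return False
    return True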

fused_llama_attn.py

@@ -7,6 +7,8 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+logger = getLogger(__name__)


 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -142,7 +144,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue lies maybe in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False

         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
                 continue
@@ -156,13 +164,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # TODO: support it. The issue lies maybe in the line:
-                    # int groups = qzeros.size(0);
-                    # in exllama_ext.cpp
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
@@ -197,6 +199,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 child_name = name

             setattr(parent, child_name, attn)
+        return True


 __all__ = ["FusedLlamaAttentionForQuantizedModel"]
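From the user side, the practical effect shows up when loading an act-order GPTQ model with both the exllama kernel and fused attention enabled. A usage sketch, assuming AutoGPTQ's from_quantized entry point and its inject_fused_attention / disable_exllama flags (the model path is a placeholder):

from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/gptq-model-with-desc_act",  # quantized with act-order (desc_act=True)
    device="cuda:0",
    inject_fused_attention=True,         # ask for q/k/v fusion
    disable_exllama=False,               # keep the exllama kernel enabled
)
# Before this commit the combination above raised:
#   ValueError: Exllama kernel does not support query/key/value fusion with act-order. ...
# After it, a warning is logged, inject_to_model returns False, and loading
# continues with the unfused attention modules.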