if exllama, auto-disable fused attention
parent 11afc47f7f
commit 8c7c806d36
2 changed files with 22 additions and 12 deletions
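In both fused attention modules touched by this commit (presumably auto_gptq/nn_modules/fused_gptj_attn.py and fused_llama_attn.py, judging by the class names and the "See fused_llama_attn.py comment" note), the combination of the exllama kernel with act-order (desc_act=True) is now detected up front: instead of raising a ValueError while iterating over the attention layers, the injection method logs a warning and returns False, so callers fall back to the stock attention modules. A minimal sketch of the new guard pattern, with hypothetical names (inject_fused_attention, quant_type) standing in for the real module code:

    from logging import getLogger

    logger = getLogger(__name__)

    def inject_fused_attention(quant_type: str, desc_act: bool) -> bool:
        # Sketch only: quant_type stands in for QuantLinear.QUANT_TYPE.
        if quant_type == "exllama" and desc_act:
            # Previously this combination raised a ValueError; now the caller is
            # simply told that fusion was skipped and keeps the unfused attention.
            logger.warning(
                "Exllama kernel does not support query/key/value fusion with "
                "act-order. Because of this, fused attention is automatically disabled."
            )
            return False
        # ... replace the attention modules here (omitted in this sketch) ...
        return True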
@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear

+from logging import getLogger
+logger = getLogger(__name__)

 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
@@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)

+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
@@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)

@@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
             del m

+        return True

 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
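For the GPT-J module, the practical effect is that the injection classmethod on FusedGPTJAttentionForQuantizedModel now reports whether fusion actually happened instead of raising. A hypothetical caller-side check, assuming the method is named inject_to_model, is importable from auto_gptq.nn_modules.fused_gptj_attn, and accepts the keyword arguments used inside the hunks above (use_triton, desc_act, group_size, bits, disable_exllama):

    from auto_gptq.nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel

    # `model` is an already-loaded, GPTQ-quantized GPT-J model (not shown here).
    injected = FusedGPTJAttentionForQuantizedModel.inject_to_model(
        model,
        use_triton=False,
        desc_act=True,          # act-order checkpoint
        group_size=128,
        bits=4,
        disable_exllama=False,  # exllama kernel stays enabled
    )
    if not injected:
        # exllama + act-order: fusion was skipped, the model keeps its original
        # GPTJAttention modules and inference still works, just without fusion.
        print("Fused attention not injected; using stock attention.")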
@@ -7,6 +7,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear

+from logging import getLogger
+
+logger = getLogger(__name__)

 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -160,7 +163,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue lies maybe in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False

         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
                 continue
@@ -169,12 +178,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             k_proj = m.k_proj
             v_proj = m.v_proj

-            if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
-                # TODO: support it. The issue lies maybe in the line:
-                # int groups = qzeros.size(0);
-                # in exllama_ext.cpp
-                raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-            elif m.num_heads == m.num_key_value_heads:
+            if m.num_heads == m.num_key_value_heads:
                 qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
                 qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
                 scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
@@ -240,3 +244,5 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             child_name = name

             setattr(parent, child_name, attn)
+
+        return True
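Taken together, requesting fused attention for an act-order model while the exllama kernel is enabled no longer errors out; it logs the warning above and loads the model without fusion. A hypothetical end-to-end example, assuming the AutoGPTQForCausalLM.from_quantized interface of this era (inject_fused_attention and disable_exllama keyword arguments) and a placeholder model path:

    from auto_gptq import AutoGPTQForCausalLM

    # Placeholder path: any GPTQ checkpoint quantized with desc_act=True (act-order).
    model = AutoGPTQForCausalLM.from_quantized(
        "path/to/act-order-gptq-model",
        device="cuda:0",
        inject_fused_attention=True,  # ask for q/k/v fusion ...
        disable_exllama=False,        # ... while keeping the exllama kernel
    )
    # Before this commit: ValueError suggesting inject_fused_attention=False or
    # disable_exllama=True. After this commit: a warning is logged and the model
    # loads with fused attention disabled.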