if exllama auto disable fused attention
parent ad5b0d72ee
commit 6b1ceb1897

2 changed files with 20 additions and 15 deletions
fused_gptj_attn.py

@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
 
 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
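Both files add a module-level logger here and downgrade the old hard error to a warning, so the message is only seen if logging is configured to show it. A minimal sketch using only the standard library; the "auto_gptq" logger name is an assumption, since the package name is not visible in this diff:

import logging

# The patched modules call logger.warning(...) on a logger created with
# getLogger(__name__); with no handlers configured, only Python's last-resort
# stderr handler shows it. Configure the root logger explicitly to be sure:
logging.basicConfig(level=logging.WARNING)

# Assumed package logger name; make sure it is not filtered above WARNING.
logging.getLogger("auto_gptq").setLevel(logging.WARNING)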
@@ -240,6 +242,10 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
     ):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
 
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
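The effect of this hunk: when the exllama kernel is selected and the checkpoint was quantized with act-order (desc_act=True), inject_to_model now warns and reports failure instead of letting the later fusion code raise. A standalone sketch of the pattern, with an illustrative function and caller that are not part of the AutoGPTQ API:

from logging import getLogger

logger = getLogger(__name__)

def try_inject_fused_attention(quant_type: str, desc_act: bool) -> bool:
    # Mirrors the guard added above: warn and skip instead of raising.
    if quant_type == "exllama" and desc_act:
        logger.warning(
            "Exllama kernel does not support query/key/value fusion with act-order. "
            "Because of this, Fused attention is automatically disabled."
        )
        return False
    # ... q/k/v fusion and module replacement would happen here ...
    return True

# A caller can keep the stock (unfused) attention modules when injection is skipped.
if not try_inject_fused_attention("exllama", desc_act=True):
    pass  # fall back to the unfused attention path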
@@ -256,10 +262,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
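The g_idx = None simplification above works because, without act-order, a GPTQ QuantLinear's group index is the trivial monotonic mapping floor(column / group_size), which the kernel can reconstruct on its own; with act-order it is a data-dependent permutation, which is exactly the case the new early return now rejects up front. A rough illustration with toy sizes, not real quantizer output:

import torch

in_features, group_size = 8, 4

# desc_act=False: the group index is just floor(column / group_size),
# so the fused exllama path can pass g_idx=None.
trivial_g_idx = torch.arange(in_features) // group_size
print(trivial_g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1])

# desc_act=True: columns are reordered during quantization, so the mapping
# becomes a permutation-dependent tensor like this made-up one, which the
# fused q/k/v path on the exllama kernel cannot represent.
act_order_g_idx = torch.tensor([1, 0, 0, 1, 1, 0, 1, 0])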
@@ -297,6 +299,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
 
             setattr(parent, child_name, attn)
             del m
 
+        return True
 
 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
fused_llama_attn.py

@@ -7,6 +7,8 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
 
 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -142,6 +144,12 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue lies maybe in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
 
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):

@@ -156,12 +164,6 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # TODO: support it. The issue lies maybe in the line:
-                    # int groups = qzeros.size(0);
-                    # in exllama_ext.cpp
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)

@@ -197,6 +199,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 child_name = name
 
             setattr(parent, child_name, attn)
 
+        return True
 
 __all__ = ["FusedLlamaAttentionForQuantizedModel"]
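Taken together, the two files change the failure mode from a hard ValueError to a logged warning plus a clean fallback: requesting fused attention on an act-order model with the exllama kernel no longer errors out, it simply skips the fusion. A usage sketch; the model path is a placeholder, and the from_quantized keyword names reflect the AutoGPTQ API of this period and should be checked against the installed version:

from auto_gptq import AutoGPTQForCausalLM

# A 4-bit GPTQ checkpoint quantized with act-order (desc_act=True); path is a placeholder.
model = AutoGPTQForCausalLM.from_quantized(
    "path/to/gptq-model-4bit-desc-act",
    device="cuda:0",
    inject_fused_attention=True,   # fusion requested...
    disable_exllama=False,         # ...with the exllama kernel enabled
)
# Before this commit: injection raised ValueError.
# After this commit: a warning is logged ("Exllama kernel does not support
# query/key/value fusion with act-order. ...") and loading continues with the
# stock attention modules.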