Compare commits
main...support-gr
2 commits: 8c7c806d36, 11afc47f7f

2 changed files with 110 additions and 60 deletions
Changed file 1 of 2 — GPTJ fused attention module:

```diff
@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
 
 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
```
```diff
@@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
 
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
```
```diff
@@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
 
```
```diff
@@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
+            del m
 
         return True
 
 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
```
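Usage sketch (not part of the diff): with the GPTJ change above, an exllama kernel combined with act-order now logs a warning and skips fusion instead of raising. Assuming the usual auto_gptq entry point — the class name, checkpoint path and exact kwargs below are assumptions that may differ by version — a caller can also opt out explicitly, as the Llama error message suggests:

```python
# Illustrative only; the checkpoint path is hypothetical and the kwargs are
# assumptions about the auto_gptq loading API, not taken from this diff.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",       # hypothetical local checkpoint
    device="cuda:0",
    inject_fused_attention=False,    # skip q/k/v fusion entirely, or ...
    disable_exllama=True,            # ... keep fusion but avoid the exllama kernel
)
```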
Changed file 2 of 2 — Llama fused attention module (fused_llama_attn.py):

```diff
@@ -2,34 +2,48 @@ import math
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 
+from logging import getLogger
+logger = getLogger(__name__)
 
 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
         self,
-        hidden_size,
-        num_heads,
+        config,
         qkv_proj,
         o_proj,
         rotary_emb,
     ):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
 
-        if self.head_dim * num_heads != self.hidden_size:
+        if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {num_heads})."
+                f" and `num_heads`: {self.num_heads})."
             )
-        self.qkv_proj = qkv_proj
+        if self.config.pretraining_tp > 1:
+            raise NotImplementedError(f"pretraining_tp of 2 or more is currently not supported.")
+
+        if len(qkv_proj) == 1:
+            self.qkv_mode = 'qkv'
+            self.qkv_proj = qkv_proj[0]
+        elif len(qkv_proj) == 2:
+            self.qkv_mode = 'q,kv'
+            self.q_proj = qkv_proj[0]
+            self.kv_proj = qkv_proj[1]
         self.o_proj = o_proj
         self.rotary_emb = rotary_emb
```
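A small illustrative sketch (not from the diff) of the grouped-query attention bookkeeping the new `__init__` pulls from the config: the query heads outnumber the key/value heads, and `num_key_value_groups` is how many query heads share each key/value head. The concrete numbers below are the Llama-2-70B GQA configuration, used here only as an example:

```python
# Example config values (Llama-2-70B); only for illustration.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8

head_dim = hidden_size // num_attention_heads                       # 128
num_key_value_groups = num_attention_heads // num_key_value_heads   # 8 query heads per k/v head

# The q projection emits num_attention_heads * head_dim = 8192 features,
# while k and v each emit only num_key_value_heads * head_dim = 1024.
print(head_dim, num_key_value_groups)
```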
```diff
@@ -39,9 +53,9 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     def forward(
         self,
         hidden_states,
-        past_key_value=None,
         attention_mask=None,
         position_ids=None,
+        past_key_value=None,
         output_attentions=False,
         use_cache=False,
         **kwargs
```
```diff
@@ -50,12 +64,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
 
         bsz, q_len, _ = hidden_states.size()
 
-        qkv_states = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
-
+        if self.qkv_mode == 'qkv':
+            qkv_states = self.qkv_proj(hidden_states)
+            query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+        elif self.qkv_mode == 'q,kv':
+            query_states = self.q_proj(hidden_states)
+            kv_states = self.kv_proj(hidden_states)
+            key_states, value_states = torch.split(kv_states, self.hidden_size, dim=2)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
```
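An illustrative sketch (not from the diff) of the `'qkv'` fused path above: a single projection emits q, k and v back to back along the last dimension, and `torch.split` with a chunk size of `hidden_size` recovers the three tensors. The sizes here are made up:

```python
import torch

# Made-up sizes for illustration: batch 2, sequence 5, hidden size 16.
bsz, q_len, hidden_size = 2, 5, 16

# A fused qkv projection produces q, k and v concatenated on the last dim.
qkv_states = torch.randn(bsz, q_len, 3 * hidden_size)

# torch.split with a fixed chunk size of hidden_size on dim=2 gives three equal chunks.
query_states, key_states, value_states = torch.split(qkv_states, hidden_size, dim=2)
assert query_states.shape == (bsz, q_len, hidden_size)
```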
```diff
@@ -79,6 +97,10 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if compare_pytorch_version("v2.0.0", op="eq"):
             attn_output = F.scaled_dot_product_attention(
                 query_states,
```
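For reference, a sketch of what `repeat_kv` (imported from transformers above) does, written out independently here; treat it as an approximation of the library helper rather than its exact source:

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim) so every query head
    has a matching key/value head."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# e.g. 8 k/v heads repeated 4x gives the 32 heads the queries expect
kv = torch.randn(1, 8, 5, 64)
assert repeat_kv_sketch(kv, 4).shape == (1, 32, 5, 64)
```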
```diff
@@ -93,7 +115,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
 
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                     f" {attn_weights.size()}"
                 )
 
```
```diff
@@ -103,7 +125,6 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
             # upcast attention to fp32
             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
```
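A minimal sketch (not part of the diff) of the two equivalent attention paths the module chooses between: the fused `F.scaled_dot_product_attention` kernel on PyTorch >= 2.0, versus the explicit softmax(QKᵀ/√d)·V computation that the fp32 upcast above belongs to. Shapes are invented for illustration:

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 5, 64)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 32, 5, 64)
v = torch.randn(1, 32, 5, 64)

# Fused kernel path (requires PyTorch >= 2.0)
out_fused = F.scaled_dot_product_attention(q, k, v)

# Explicit path, mirroring the manual branch: scores, softmax in fp32, weighted sum
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
out_manual = torch.matmul(attn_weights, v)

assert torch.allclose(out_fused, out_manual, atol=1e-4)
```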
```diff
@@ -142,6 +163,12 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue lies maybe in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
 
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
```
```diff
@@ -151,41 +178,61 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             k_proj = m.k_proj
             v_proj = m.v_proj
 
-            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-            qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-
-            if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # TODO: support it. The issue lies maybe in the line:
-                    # int groups = qzeros.size(0);
-                    # in exllama_ext.cpp
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
-            else:
-                g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
-
-            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-            qlinear_args = (
-                q_proj.bits,
-                q_proj.group_size,
-                q_proj.infeatures,
-                q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
-                True if q_proj.bias is not None else False,
-            )
-            qlinear_kwargs = {"trainable": trainable}
-            if (not desc_act or group_size == -1) and not use_triton:
-                qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
-            qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
-            qkv_layer.qweight = qweights
-            qkv_layer.qzeros = qzeros
-            qkv_layer.scales = scales
-            qkv_layer.g_idx = g_idx
-            qkv_layer.bias = bias
-
-            attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
+            if m.num_heads == m.num_key_value_heads:
+                qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+                qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+                scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+                if QuantLinear.QUANT_TYPE == "exllama":
+                    g_idx = None
+                else:
+                    g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+                bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+                qlinear_args = (
+                    q_proj.bits,
+                    q_proj.group_size,
+                    q_proj.infeatures,
+                    q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
+                    True if q_proj.bias is not None else False,
+                )
+                qlinear_kwargs = {"trainable": trainable}
+                if (not desc_act or group_size == -1) and not use_triton:
+                    qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
+                qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
+                qkv_layer.qweight = qweights
+                qkv_layer.qzeros = qzeros
+                qkv_layer.scales = scales
+                qkv_layer.g_idx = g_idx
+                qkv_layer.bias = bias
+                qkv_layers = [qkv_layer]
+            else:
+                qweights = torch.cat([k_proj.qweight, v_proj.qweight], dim=1)
+                qzeros = torch.cat([k_proj.qzeros, v_proj.qzeros], dim=1)
+                scales = torch.cat([k_proj.scales, v_proj.scales], dim=1)
+                if QuantLinear.QUANT_TYPE == "exllama":
+                    g_idx = None
+                else:
+                    g_idx = torch.cat([k_proj.g_idx, v_proj.g_idx], dim=0)
+                bias = torch.cat([k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+                qlinear_args = (
+                    k_proj.bits,
+                    k_proj.group_size,
+                    k_proj.infeatures,
+                    k_proj.outfeatures + v_proj.outfeatures,
+                    True if q_proj.bias is not None else False,
+                )
+                qlinear_kwargs = {"trainable": trainable}
+                if (not desc_act or group_size == -1) and not use_triton:
+                    qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
+                kv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
+                kv_layer.qweight = qweights
+                kv_layer.qzeros = qzeros
+                kv_layer.scales = scales
+                kv_layer.g_idx = g_idx
+                kv_layer.bias = bias
+                qkv_layers = [q_proj, kv_layer]
+            attn = cls(m.config, qkv_layers, m.o_proj, m.rotary_emb)
 
             if '.' in name:
                 parent_name = name.rsplit('.', 1)[0]
```
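A toy sketch (not from the diff) of the column-wise fusion idea behind the concatenations above: stacking per-projection weight-like tensors along dim=1 places the q, k and v output features side by side, so a single matmul produces all three projections at once. Plain float tensors stand in for the packed quantized buffers, and all sizes are made up:

```python
import torch

in_features = 8
q_out, k_out, v_out = 8, 4, 4   # made-up output sizes

w_q = torch.randn(in_features, q_out)
w_k = torch.randn(in_features, k_out)
w_v = torch.randn(in_features, v_out)

# Side by side along the output dimension, mirroring the qweight/qzeros/scales cats.
w_qkv = torch.cat([w_q, w_k, w_v], dim=1)         # shape (8, 16)

x = torch.randn(2, in_features)
fused = x @ w_qkv                                  # one matmul for all three projections
q, k, v = torch.split(fused, [q_out, k_out, v_out], dim=1)
assert torch.allclose(q, x @ w_q, atol=1e-6)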
```diff
@@ -198,5 +245,4 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
 
             setattr(parent, child_name, attn)
 
-
-__all__ = ["FusedLlamaAttentionForQuantizedModel"]
+        return True
```