Compare commits


2 commits

Author        SHA1        Message                                  Date
qwopqwop200   8c7c806d36  if exllama auto diable fused attention   2023-08-07 19:24:16 +09:00
qwopqwop200   11afc47f7f  support gqa                              2023-08-07 19:00:05 +09:00
2 changed files with 110 additions and 60 deletions

auto_gptq/nn_modules/fused_gptj_attn.py

@@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+logger = getLogger(__name__)
 
 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     dim = x.shape[-1]
@@ -241,6 +243,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         config = model.config
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # See fused_llama_attn.py comment
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
                 continue
@@ -256,11 +263,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
 
             if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # See fused_llama_attn.py comment
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
+                g_idx = None
             else:
                 g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
@@ -298,5 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
             setattr(parent, child_name, attn)
 
             del m
+        return True
 
 __all__ = ["FusedGPTJAttentionForQuantizedModel"]
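Aside: the q/k/v fusion in both files works because a GPTQ QuantLinear stores its tensors with input features packed/grouped along dim 0 and output features along dim 1, so concatenating along dim=1 yields one projection whose output splits back into q, k and v. A minimal sketch under assumed 4-bit packing shapes (dummy_quant_tensors and the concrete sizes are illustrative, not the repository's API):

    import torch

    # Assumed GPTQ packing: with 4-bit weights, 8 values fit in one int32.
    bits, group_size, in_features = 4, 128, 4096
    pack = 32 // bits

    def dummy_quant_tensors(out_features):
        # Shapes follow the usual GPTQ QuantLinear layout: input features are
        # packed/grouped along dim 0, output features live on dim 1.
        qweight = torch.zeros(in_features // pack, out_features, dtype=torch.int32)
        qzeros = torch.zeros(in_features // group_size, out_features // pack, dtype=torch.int32)
        scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)
        return qweight, qzeros, scales

    q = dummy_quant_tensors(4096)
    k = dummy_quant_tensors(4096)
    v = dummy_quant_tensors(4096)

    # torch.cat(..., dim=1) stacks output features, mirroring the diffs here:
    # one fused matmul now computes q, k and v in a single pass.
    fused_qweight = torch.cat([q[0], k[0], v[0]], dim=1)
    fused_qzeros = torch.cat([q[1], k[1], v[1]], dim=1)
    fused_scales = torch.cat([q[2], k[2], v[2]], dim=1)
    print(fused_qweight.shape)  # torch.Size([512, 12288])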

auto_gptq/nn_modules/fused_llama_attn.py

@@ -2,34 +2,48 @@ import math
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+from logging import getLogger
+logger = getLogger(__name__)
 
 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
         self,
-        hidden_size,
-        num_heads,
+        config,
         qkv_proj,
         o_proj,
         rotary_emb,
     ):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
 
-        if self.head_dim * num_heads != self.hidden_size:
+        if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {num_heads})."
+                f" and `num_heads`: {self.num_heads})."
             )
-        self.qkv_proj = qkv_proj
+        if self.config.pretraining_tp > 1:
+            raise NotImplementedError(f"pretraining_tp of 2 or more is currently not supported.")
+
+        if len(qkv_proj) == 1:
+            self.qkv_mode = 'qkv'
+            self.qkv_proj = qkv_proj[0]
+        elif len(qkv_proj) == 2:
+            self.qkv_mode = 'q,kv'
+            self.q_proj = qkv_proj[0]
+            self.kv_proj = qkv_proj[1]
         self.o_proj = o_proj
         self.rotary_emb = rotary_emb
@@ -39,9 +53,9 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     def forward(
         self,
         hidden_states,
-        past_key_value=None,
         attention_mask=None,
         position_ids=None,
+        past_key_value=None,
         output_attentions=False,
         use_cache=False,
         **kwargs
@@ -49,14 +63,18 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         """Input shape: Batch x Time x Channel"""
 
         bsz, q_len, _ = hidden_states.size()
 
-        qkv_states = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+        if self.qkv_mode == 'qkv':
+            qkv_states = self.qkv_proj(hidden_states)
+            query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+        elif self.qkv_mode == 'q,kv':
+            query_states = self.q_proj(hidden_states)
+            kv_states = self.kv_proj(hidden_states)
+            key_states, value_states = torch.split(kv_states, self.hidden_size, dim=2)
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
@@ -79,6 +97,10 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if compare_pytorch_version("v2.0.0", op="eq"):
             attn_output = F.scaled_dot_product_attention(
                 query_states,
@@ -93,7 +115,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                     f" {attn_weights.size()}"
                 )
@@ -103,7 +125,6 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
             # upcast attention to fp32
             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -142,7 +163,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
         QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+
+        if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
+            # TODO: support it. The issue lies maybe in the line:
+            # int groups = qzeros.size(0);
+            # in exllama_ext.cpp
+            logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
+            return False
+
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
                 continue
@@ -151,41 +178,61 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
             k_proj = m.k_proj
             v_proj = m.v_proj
 
-            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-            qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-
-            if QuantLinear.QUANT_TYPE == "exllama":
-                if desc_act:
-                    # TODO: support it. The issue lies maybe in the line:
-                    # int groups = qzeros.size(0);
-                    # in exllama_ext.cpp
-                    raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
-                else:
-                    g_idx = None
-            else:
-                g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
-
-            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-            qlinear_args = (
-                q_proj.bits,
-                q_proj.group_size,
-                q_proj.infeatures,
-                q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
-                True if q_proj.bias is not None else False,
-            )
-            qlinear_kwargs = {"trainable": trainable}
-            if (not desc_act or group_size == -1) and not use_triton:
-                qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
-            qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
-            qkv_layer.qweight = qweights
-            qkv_layer.qzeros = qzeros
-            qkv_layer.scales = scales
-            qkv_layer.g_idx = g_idx
-            qkv_layer.bias = bias
-
-            attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
+            if m.num_heads == m.num_key_value_heads:
+                qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+                qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+                scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+                if QuantLinear.QUANT_TYPE == "exllama":
+                    g_idx = None
+                else:
+                    g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+                bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+                qlinear_args = (
+                    q_proj.bits,
+                    q_proj.group_size,
+                    q_proj.infeatures,
+                    q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
+                    True if q_proj.bias is not None else False,
+                )
+                qlinear_kwargs = {"trainable": trainable}
+                if (not desc_act or group_size == -1) and not use_triton:
+                    qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
+                qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
+                qkv_layer.qweight = qweights
+                qkv_layer.qzeros = qzeros
+                qkv_layer.scales = scales
+                qkv_layer.g_idx = g_idx
+                qkv_layer.bias = bias
+                qkv_layers = [qkv_layer]
+            else:
+                qweights = torch.cat([k_proj.qweight, v_proj.qweight], dim=1)
+                qzeros = torch.cat([k_proj.qzeros, v_proj.qzeros], dim=1)
+                scales = torch.cat([k_proj.scales, v_proj.scales], dim=1)
+                if QuantLinear.QUANT_TYPE == "exllama":
+                    g_idx = None
+                else:
+                    g_idx = torch.cat([k_proj.g_idx, v_proj.g_idx], dim=0)
+                bias = torch.cat([k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+                qlinear_args = (
+                    k_proj.bits,
+                    k_proj.group_size,
+                    k_proj.infeatures,
+                    k_proj.outfeatures + v_proj.outfeatures,
+                    True if q_proj.bias is not None else False,
+                )
+                qlinear_kwargs = {"trainable": trainable}
+                if (not desc_act or group_size == -1) and not use_triton:
+                    qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
+                kv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
+                kv_layer.qweight = qweights
+                kv_layer.qzeros = qzeros
+                kv_layer.scales = scales
+                kv_layer.g_idx = g_idx
+                kv_layer.bias = bias
+                qkv_layers = [q_proj, kv_layer]
+
+            attn = cls(m.config, qkv_layers, m.o_proj, m.rotary_emb)
 
             if '.' in name:
                 parent_name = name.rsplit('.', 1)[0]
@@ -197,6 +244,5 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
                 child_name = name
 
             setattr(parent, child_name, attn)
-
+        return True
 
 __all__ = ["FusedLlamaAttentionForQuantizedModel"]
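Aside: a self-contained sketch of the grouped-query attention (GQA) bookkeeping the "support gqa" commit relies on. repeat_kv expands each key/value head so that num_heads // num_key_value_heads query heads share it; the function below mirrors the transformers helper rather than importing it, and the 64-query-head / 8-KV-head numbers are assumed for illustration:

    import torch

    def repeat_kv_sketch(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
        # hidden: (batch, num_key_value_heads, seq_len, head_dim)
        if n_rep == 1:
            return hidden  # MHA case: every query head already has its own KV head
        bsz, n_kv, slen, hdim = hidden.shape
        # Insert a repeat axis, broadcast it, then fold it into the head axis.
        hidden = hidden[:, :, None, :, :].expand(bsz, n_kv, n_rep, slen, hdim)
        return hidden.reshape(bsz, n_kv * n_rep, slen, hdim)

    # 64 query heads sharing 8 KV heads -> num_key_value_groups = 8.
    key_states = torch.randn(1, 8, 16, 128)
    print(repeat_kv_sketch(key_states, 64 // 8).shape)  # torch.Size([1, 64, 16, 128])

This head-count mismatch is also why the injection path above fuses only k_proj and v_proj when num_heads != num_key_value_heads: the q projection has num_heads * head_dim output features while k and v have only num_key_value_heads * head_dim each, so a three-way split of one fused output would no longer line up.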