import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xop
from vllm import pos_encoding_ops as vllm_pos_encoding_ops
from xformers.ops.fmha.attn_bias import LowerTriangularMask, LowerTriangularMaskWithTensorBias


POTENTIAL_KV_CACHE_NAMES = (
    "past_key_value",
    "layer_past",
    "kv_cache",
)


def _try_to_get_kv_cache(**kwargs) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    # Different model implementations pass the KV cache under different keyword
    # names; return the first match, or None when no cache was provided.
    for name in POTENTIAL_KV_CACHE_NAMES:
        if name in kwargs:
            return kwargs[name]
    return None


def build_rope_cache(
    rotary_dim: int,
    max_position: int = 2048,
    base: int = 10000,
    device: torch.device = torch.device("cuda:0"),
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:  # TODO: support multiple scaling strategies
    # Standard RoPE cache: one row per position, with cos in the first half and
    # sin in the second half of the last dimension ([max_position, rotary_dim]).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim))
    t = torch.arange(max_position, device=device, dtype=dtype)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cache = torch.cat((cos, sin), dim=-1)
    return cache


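# Minimal usage sketch for the cache above (assumptions: a CUDA device is
# available and head_dim is the per-head hidden size of the target model):
#
#     cos_sin_cache = build_rope_cache(rotary_dim=head_dim, max_position=2048)
#
# The result has shape [max_position, rotary_dim] with cos in the first half and
# sin in the second half of the last dimension, which is the layout that
# FusedAttentionWithRoPE below hands to vLLM's rotary_embedding_neox kernel.

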
def build_alibi_slopes(
    num_heads: int,
    device: torch.device = torch.device("cuda:0"),
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # ALiBi slopes form a geometric sequence per head; when num_heads is not a
    # power of two, extra slopes are interleaved from the next power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    slopes = slopes.to(dtype)
    return slopes


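# Hedged example (num_heads=32 is illustrative): the returned [num_heads] tensor
# is what FusedAttentionWithALiBi below registers as its "alibi_slopes" buffer
# and uses to scale the per-head linear position bias.
#
#     alibi_slopes = build_alibi_slopes(num_heads=32)

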
def attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_ops: Optional[xop.AttentionOp] = (xop.fmha.flash.FwOp(), None),
    attention_bias: Optional[xop.AttentionBias] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # query, key and value are expected in BMHK layout: [batch, seq_len, heads, head_dim].
    num_query_heads = query.shape[2]
    if value.shape[2] != num_query_heads:
        # MQA / GQA: repeat the K/V heads so that query, key and value expose the
        # same number of heads before calling xformers. This is a simple
        # materializing fallback; a broadcasted grouped layout would avoid the copy.
        num_groups = num_query_heads // value.shape[2]
        if key.shape[2] != num_query_heads:
            key = key.repeat_interleave(num_groups, dim=2)
        value = value.repeat_interleave(num_groups, dim=2)

    return xop.memory_efficient_attention(
        query=query,
        key=key,
        value=value,
        attn_bias=attention_bias,
        p=p,
        scale=scale,
        op=attention_ops,
    )


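# Hedged standalone example (shapes and dtypes are illustrative): calling
# attention() directly with BMHK tensors and a causal mask; the default op tuple
# requests the flash forward kernel.
#
#     q = torch.randn(1, 128, 16, 64, device="cuda", dtype=torch.float16)
#     k = torch.randn(1, 128, 16, 64, device="cuda", dtype=torch.float16)
#     v = torch.randn(1, 128, 16, 64, device="cuda", dtype=torch.float16)
#     out = attention(q, k, v, attention_bias=LowerTriangularMask())
#     # out has shape [1, 128, 16, 64]

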
class FusedAttention(nn.Module):
    def __init__(
        self,
        qkv_proj: nn.Linear,
        out_proj: nn.Linear,
        num_query_heads: int,
        num_key_heads: int,
        num_value_heads: int,
        attn_dropout: float = 0.0,
        resid_dropout: float = 0.0,
        scale: Optional[float] = None,
        attention_ops: Optional[xop.AttentionOp] = None,
        outputs_handler: Optional[Callable] = None,
        training: bool = False,
    ):
        super(FusedAttention, self).__init__()

        self.qkv_proj = qkv_proj
        self.out_proj = out_proj

        self.num_query_heads = num_query_heads
        self.num_key_heads = num_key_heads
        self.num_value_heads = num_value_heads

        # Dropout is only applied in training mode.
        self.attn_dropout = attn_dropout if training else 0.0
        self.scale = scale

        self.attention_ops = attention_ops

        self.outputs_handler = outputs_handler

        self.resid_dropout = nn.Dropout(resid_dropout if training else 0.0)

    def _build_attn_bias(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Optional[xop.AttentionBias]:
        # Subclasses override this to supply causal / ALiBi biases.
        return None

    def _attn(
        self,
        batch_size: int,
        seq_len: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[xop.AttentionBias] = None,
        use_cache: bool = False,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ):
        # Reshape to [batch, heads, seq_len, head_dim] so cached keys/values can
        # be concatenated along the sequence dimension (dim=2).
        q = q.view(batch_size, seq_len, self.num_query_heads, -1).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_key_heads, -1).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_value_heads, -1).transpose(1, 2)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat((k_cache, k), dim=2)
            v = torch.cat((v_cache, v), dim=2)

        present = (k, v) if use_cache else None

        # xformers expects BMHK layout, so transpose back before the fused kernel.
        attn_out = attention(
            query=q.transpose(1, 2),
            key=k.transpose(1, 2),
            value=v.transpose(1, 2),
            attention_ops=self.attention_ops,
            attention_bias=attention_bias,
            p=self.attn_dropout,
            scale=self.scale
        ).view(batch_size, seq_len, -1)

        return attn_out, present

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        bsz, seq_len = hidden_states.shape[:2]
        use_cache = kwargs.get("use_cache", False)
        kv_cache = _try_to_get_kv_cache(**kwargs)

        q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)

        attn_bias = self._build_attn_bias(hidden_states, attention_mask)
        attn_out, present = self._attn(
            bsz,
            seq_len,
            q,
            k,
            v,
            attn_bias,
            use_cache,
            kv_cache
        )

        out = self.out_proj(attn_out)
        out = self.resid_dropout(out)

        # (hidden states, present KV cache, attention-weights placeholder)
        outputs = (out, present, None)
        if self.outputs_handler:
            outputs = self.outputs_handler(*outputs)

        return outputs


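# Hedged integration sketch (the `block.attn` / `config` attribute names are
# assumptions about the host model, not a specific API): FusedAttention is meant
# to replace an attention block whose QKV projection is a single fused nn.Linear.
#
#     fused = FusedAttention(
#         qkv_proj=block.attn.qkv_proj,
#         out_proj=block.attn.out_proj,
#         num_query_heads=config.num_attention_heads,
#         num_key_heads=config.num_attention_heads,
#         num_value_heads=config.num_attention_heads,
#         outputs_handler=lambda out, present, weights: (out, present),
#     )
#     block.attn = fused

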
class FusedAttentionWithRoPE(FusedAttention):
    def __init__(
        self,
        qkv_proj: nn.Linear,
        out_proj: nn.Linear,
        cos_sin_cache: torch.Tensor,
        num_query_heads: int,
        num_key_heads: int,
        num_value_heads: int,
        attn_dropout: float = 0.0,
        resid_dropout: float = 0.0,
        scale: Optional[float] = None,
        attention_ops: Optional[xop.AttentionOp] = None,
        outputs_handler: Optional[Callable] = None,
        training: bool = False,
    ):
        super(FusedAttentionWithRoPE, self).__init__(
            qkv_proj=qkv_proj,
            out_proj=out_proj,
            num_query_heads=num_query_heads,
            num_key_heads=num_key_heads,
            num_value_heads=num_value_heads,
            attn_dropout=attn_dropout,
            resid_dropout=resid_dropout,
            scale=scale,
            attention_ops=attention_ops,
            outputs_handler=outputs_handler,
            training=training
        )

        self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)

    def _build_attn_bias(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Optional[xop.AttentionBias]:
        return LowerTriangularMask()

    def _apply_rotary_embedding(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ):
        bsz, seq_len, hidden_size = query.shape

        if position_ids is not None:
            # The vLLM kernel rotates query and key in place over flattened
            # [num_tokens, hidden_size] views; when position_ids is None the
            # projections are returned unchanged.
            query = query.view(bsz * seq_len, -1)
            key = key.view(bsz * seq_len, -1)
            vllm_pos_encoding_ops.rotary_embedding_neox(
                position_ids.view(-1).to(query.device),
                query,
                key,
                hidden_size // self.num_query_heads,
                self.cos_sin_cache,
            )
            query = query.view(bsz, seq_len, -1)
            key = key.view(bsz, seq_len, -1)

        return query, key

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        bsz, seq_len = hidden_states.shape[:2]
        position_ids = kwargs.get("position_ids", None)
        use_cache = kwargs.get("use_cache", False)
        kv_cache = _try_to_get_kv_cache(**kwargs)

        q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)

        q, k = self._apply_rotary_embedding(q, k, position_ids)

        # During incremental decoding (kv_cache present) the new query may attend
        # to every cached position, so no causal bias is applied.
        attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
        attn_out, present = self._attn(
            bsz,
            seq_len,
            q,
            k,
            v,
            attn_bias,
            use_cache,
            kv_cache
        )

        out = self.out_proj(attn_out)
        out = self.resid_dropout(out)

        outputs = (out, present, None)
        if self.outputs_handler:
            outputs = self.outputs_handler(*outputs)

        return outputs


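# Hedged call sketch for the module above (sizes are illustrative; `attn` is a
# FusedAttentionWithRoPE built with a matching cos_sin_cache): position_ids
# indexes rows of cos_sin_cache, so it must hold the absolute positions of the
# tokens being fed; the kernel derives the per-head size as
# hidden_size // num_query_heads.
#
#     hidden = torch.randn(1, 128, 4096, device="cuda", dtype=torch.float16)
#     position_ids = torch.arange(128, device="cuda").unsqueeze(0)
#     out, present, _ = attn(hidden, position_ids=position_ids, use_cache=True)

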
class FusedAttentionWithALiBi(FusedAttention):
    def __init__(
        self,
        qkv_proj: nn.Linear,
        out_proj: nn.Linear,
        alibi_slopes: torch.Tensor,
        num_query_heads: int,
        num_key_heads: int,
        num_value_heads: int,
        attn_dropout: float = 0.0,
        resid_dropout: float = 0.0,
        scale: Optional[float] = None,
        attention_ops: Optional[xop.AttentionOp] = None,
        outputs_handler: Optional[Callable] = None,
        training: bool = False,
    ):
        super(FusedAttentionWithALiBi, self).__init__(
            qkv_proj=qkv_proj,
            out_proj=out_proj,
            num_query_heads=num_query_heads,
            num_key_heads=num_key_heads,
            num_value_heads=num_value_heads,
            attn_dropout=attn_dropout,
            resid_dropout=resid_dropout,
            scale=scale,
            attention_ops=attention_ops,
            outputs_handler=outputs_handler,
            training=training
        )

        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

    def _build_attn_bias(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Optional[xop.AttentionBias]:  # adapted from vLLM
        bsz, seq_len = hidden_states.shape[:2]

        # Relative distances: bias[i, j] = j - i.
        bias = torch.arange(seq_len, device=hidden_states.device)
        bias = bias[None, :] - bias[:, None]

        # When using a custom attention bias, xformers requires the bias to be
        # sliced from a tensor whose length is a multiple of 8. Keep the bias in
        # the activation dtype so it matches the query in the fused kernel.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            self.num_query_heads,
            padded_len,
            padded_len,
            device=self.alibi_slopes.device,
            dtype=hidden_states.dtype,
        )[:, :seq_len, :seq_len].copy_(bias)
        bias.mul_(self.alibi_slopes[:, None, None])
        bias = LowerTriangularMaskWithTensorBias(bias.unsqueeze(0).repeat(bsz, 1, 1, 1))

        return bias

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        bsz, seq_len = hidden_states.shape[:2]
        use_cache = kwargs.get("use_cache", False)
        kv_cache = _try_to_get_kv_cache(**kwargs)

        q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)

        attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
        attn_out, present = self._attn(
            bsz,
            seq_len,
            q,
            k,
            v,
            attn_bias,
            use_cache,
            kv_cache
        )

        out = self.out_proj(attn_out)
        out = self.resid_dropout(out)

        outputs = (out, present, None)
        if self.outputs_handler:
            outputs = self.outputs_handler(*outputs)

        return outputs
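

# Hedged end-to-end sketch (hidden size 4096 and 32 heads are illustrative
# assumptions): wiring build_alibi_slopes into FusedAttentionWithALiBi.
#
#     slopes = build_alibi_slopes(num_heads=32)
#     attn = FusedAttentionWithALiBi(
#         qkv_proj=nn.Linear(4096, 3 * 4096, device="cuda", dtype=torch.float16),
#         out_proj=nn.Linear(4096, 4096, device="cuda", dtype=torch.float16),
#         alibi_slopes=slopes,
#         num_query_heads=32,
#         num_key_heads=32,
#         num_value_heads=32,
#     )
#     hidden = torch.randn(1, 128, 4096, device="cuda", dtype=torch.float16)
#     out, present, _ = attn(hidden, use_cache=True)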