Compare commits

36 commits

Author SHA1 Message Date
PanQiWei
d95661b250 improve code format 2023-08-13 16:14:15 +08:00
PanQiWei
0a04d3fb2a explicit set "base" value 2023-08-13 16:14:01 +08:00
PanQiWei
7c2ec905a6 extrac rope logic into a single method for better override in child class 2023-08-13 16:13:44 +08:00
PanQiWei
b1c64d9269 add baichuan model attention fusion logic 2023-08-11 19:12:43 +08:00
PanQiWei
eeee7b344f add trust_remote_code argument in tokenizer init 2023-08-11 19:12:30 +08:00
PanQiWei
8fedbbf82d using transformers gptj rope implementation 2023-08-11 18:26:23 +08:00
PanQiWei
fdb8c4500a extend to support qlinear_exllama's fusion 2023-08-11 14:52:26 +08:00
PanQiWei
43b9a5cd0a fix pass in wrong argument 2023-08-10 15:49:52 +08:00
PanQiWei
efe47aafe5 prevent potential import error 2023-08-10 15:36:54 +08:00
PanQiWei
3d09cf36d7 fix syntax error 2023-08-10 15:36:21 +08:00
潘其威(William)
beab695c5b Merge branch 'main' into xformers_integration 2023-08-10 15:27:11 +08:00
PanQiWei
edc5b72da4 using pytorch backend rope 2023-08-09 10:20:58 +08:00
PanQiWei
26dc6852fe support inherit one of the three fused attention class and customize attn_bias building logic 2023-08-07 18:59:04 +08:00
PanQiWei
d73ed1cfc2 freeze triton version to 2.0.0 2023-08-07 18:55:59 +08:00
PanQiWei
e5f874e5af add fused attention injection logic to llama 2023-08-07 13:45:37 +08:00
PanQiWei
700406e6b6 add 'inject_fused_attention' and 'inject_fused_mlp' action flags 2023-08-06 18:50:28 +08:00
PanQiWei
2092a80b81 keep attn_op as what it is when passed in 2023-08-06 18:38:25 +08:00
PanQiWei
4aea0aef39 update benchmark script 2023-08-06 18:37:23 +08:00
PanQiWei
1f9717af7f change classes default values 2023-08-06 18:24:23 +08:00
PanQiWei
7a70bcf6d8 doing 'memory_efficient_fusion' in __init__ 2023-08-06 17:23:57 +08:00
PanQiWei
57c3e5b7d5 change CL 2023-08-06 16:24:59 +08:00
PanQiWei
01ce32553e remove unnecessary lines 2023-08-06 16:24:44 +08:00
PanQiWei
677409e2fe fix using wrong attribute 2023-08-06 16:23:19 +08:00
PanQiWei
9155ef3038 fix using wrong attribute 2023-08-06 15:37:11 +08:00
PanQiWei
df24da5797 mark fused ops injection as experiment features 2023-08-06 15:05:28 +08:00
PanQiWei
ab6faa6496 implement gptj attention and mlp fused ops injection logic 2023-08-06 14:55:06 +08:00
PanQiWei
f67b512cee add 'training' argument 2023-08-06 14:54:34 +08:00
PanQiWei
bacac399d3 abandon change trainable mode after model is loaded; support specify customized AttentionOp 2023-08-06 14:15:43 +08:00
PanQiWei
c71f5cdf12 add '_fuse_attention' and '_fuse_mlp' abstract static methods 2023-08-06 12:45:08 +08:00
PanQiWei
0fcfddda90 rename 'inject_to_model' to 'convert_to_torch_linear' 2023-08-06 12:09:16 +08:00
PanQiWei
2826729e73 use pytorch normal forward logic when trainable is True 2023-08-06 11:44:29 +08:00
PanQiWei
801610367d Merge branch 'main' into xformers_integration 2023-08-05 18:02:00 +08:00
PanQiWei
7d0909160c add fused MLPs 2023-08-04 20:03:16 +08:00
PanQiWei
8b19122775 add fused attentions 2023-08-04 19:11:43 +08:00
PanQiWei
cd8a674002 add FusedGeneralQuantLinear 2023-08-04 19:10:32 +08:00
PanQiWei
116d8267d7 update requirements 2023-08-04 19:10:05 +08:00
18 changed files with 1197 additions and 145 deletions

View file

@ -17,11 +17,11 @@ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
from transformers.modeling_utils import no_init_weights
from xformers.ops.fmha import AttentionOp
from ._const import *
from ._utils import *
from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import (
@ -89,7 +89,6 @@ class BaseQuantizeConfig(PushToHubMixin):
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
with open(resolved_config_file, "r", encoding="utf-8") as f:
return cls(**json.load(f))
@ -114,9 +113,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
inside_layer_modules: List[List[str]] = None
lm_head_name: str = "lm_head"
fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
fused_mlp_module_type: Optional[FusedBaseMLPModule] = None
def __init__(
self,
model: PreTrainedModel,
@ -210,14 +206,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")
if use_triton and not TRITON_AVAILABLE:
logger.warning("triton is not installed, reset use_triton to False")
logger.warning("Triton is not installed, reset use_triton to False")
use_triton = False
device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu":
logger.info(f"truly offloading {name} to cpu with hook.")
logger.info(f"Truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, CUDA_0)
@ -437,10 +433,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def generate(self, **kwargs):
def generate(self, *args, **kwargs):
"""shortcut for model.generate"""
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)
with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(*args, **kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
@ -489,8 +485,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
"""
if (self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)) and save_dir is None:
raise ValueError("Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading.")
if (
self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)
) and save_dir is None:
raise ValueError(
"Quantized model should be saved first, or you can provide save_dir to "
"make sure model is saved to local disk before uploading."
)
if save_dir is not None:
logger.info(f"Saving model to {save_dir}")
@ -517,16 +518,30 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
repo_type="model",
)
def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
def save_quantized(
self,
save_dir: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None
):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
raise TypeError("Can only save quantized model, please execute .quantize() method first.")
if self.injected_fused_attention or self.injected_fused_mlp:
raise TypeError(
"At least one of attention modules and mlp modules are injected with fused ops, "
"please disable 'inject_fused_attention' and 'inject_fused_mlp' at model loading stage, "
"and don't call ._fuse_attention() and ._fuse_mlp() methods before calling this method."
)
self.model.to(CPU)
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
@ -546,13 +561,23 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and "
f"an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
logger.warning(
f"After converting safetensors_metadata keys to strings, the key "
f"'{new_key}' is duplicated. Ensure that all your metadata keys are "
f"strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
logger.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). "
f"Final safetensors_metadata: {safetensors_metadata}"
)
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
@ -576,9 +601,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
self.quantize_config.model_name_or_path = save_dir
self.quantize_config.model_file_base_name = model_base_name
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
def save_pretrained(
self,
save_dir: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
**kwargs
):
"""alias of save_quantized"""
logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
logger.warning("You are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
@classmethod
@ -672,9 +703,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model.seqlen = model_config[key]
break
else:
logger.warning("can't get model's sequence length from model config, will set to 4096.")
logger.warning("Can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
model.eval()
return cls(model, False, quantize_config)
@ -688,8 +718,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
torch_dtype: torch.dtype = torch.float16,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True,
inject_fused_attention: bool = False,
inject_fused_mlp: bool = False,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
@ -697,6 +727,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
attn_op: Optional[AttentionOp] = None,
disable_exllama: bool = False,
**kwargs
):
@ -727,7 +758,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
}
if use_triton and not TRITON_AVAILABLE:
logger.warning("Triton is not installed, reset use_triton to False.")
logger.warning("Triton is not installed, reset use_triton to False")
use_triton = False
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
logger.warning(
@ -746,9 +777,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"2. You are using pytorch without CUDA support.\n"
"3. CUDA and nvcc are not installed in your device."
)
if any([inject_fused_attention, inject_fused_mlp]) and trainable:
logger.warning(
"Neither fused attention nor fused mlp is tested under trainable mode, "
"this may cause unexpected behavior or lead to error if you are training "
"a quantized model with fused ops, please consider disabling 'inject_fused_attention' "
"and 'inject_fused_mlp'."
)
# == step1: prepare configs and file names == #
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
**cached_file_kwargs
)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
@ -795,9 +837,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if not disable_exllama and trainable:
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
disable_exllama = True
elif not use_triton and trainable:
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
logger.warning(
"QuantLinear with cuda backend not support trainable mode yet, will switch to pytorch backend, "
"this may cause very slow inference speed, disable trainable if you are not training model."
)
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
@ -881,7 +925,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
)
model = simple_dispatch_model(model, device_map)
# == step4: set seqlen == #
# == step4: post init model == #
# Any post-initialization that require device information, for example buffers initialization on device.
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
# == step5: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
@ -890,50 +938,62 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model.seqlen = model_config[key]
break
else:
logger.warning("can't get model's sequence length from model config, will set to 4096.")
logger.warning("Can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
# == step5: (optional) inject optimized module == #
# == step6: (optional) inject optimized module == #
if inject_fused_attention:
if cls.fused_attn_module_type is None:
try:
cls._fuse_attention(model, attn_op, trainable)
except NotImplementedError:
inject_fused_attention = False
logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
else:
cls.fused_attn_module_type.inject_to_model(
model,
use_triton=use_triton,
group_size=quantize_config.group_size,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
logger.warning(
f"{cls.__name__} doesn't support fusing attention yet, will skip inject fused attention."
)
except:
logger.error(
f"Inject fused attention failed, you can set 'inject_fused_attention' to False to "
f"bypass the error for now and report it on github."
)
raise
if inject_fused_mlp:
if cls.fused_mlp_module_type is None:
try:
cls._fuse_mlp(model, trainable)
except NotImplementedError:
inject_fused_mlp = False
logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
else:
cls.fused_mlp_module_type.inject_to_model(
model,
use_triton=use_triton
logger.warning(
f"{cls.__name__} doesn't support fusing mlp yet, will skip inject fused mlp."
)
except:
logger.error(
f"Inject fused mlp failed, you can set 'inject_fused_mlp' to False to "
f"bypass the error for now and report it on github."
)
raise
if inject_fused_attention or inject_fused_mlp:
logger.warning(
"You are using at least one of 'inject_fused_attention' and 'inject_fused_mlp' "
"modes, which are now marked as experimental features, feel free to open an issue "
"or ask any question about those two features on github if you encounter unexpected "
"behaviors and errors."
)
# Any post-initialization that require device information, for example buffers initialization on device.
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
model.eval()
# == step6: (optional) warmup triton == #
# == step7: (optional) warmup triton == #
if use_triton and warmup_triton:
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
QuantLinear.warmup(model, seqlen=model.seqlen)
cls.warmup_triton(model)
if inject_fused_mlp and cls.fused_mlp_module_type is not None:
cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
# == step7: make model compatible with peft
cls.make_sure_compatible_with_peft(
model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
# == step8: convert all QuantLinear to sub-class of torch.nn.Linear
# note if _fuse_attention() and _fuse_mlp() is implemented,
# all QuantLinear will be converted to sub-class of torch.nn.Linear at injection stage
GeneralQuantLinear.convert_to_torch_linear(
model,
dynamically_import_QuantLinear(
use_triton,
quantize_config.desc_act,
quantize_config.group_size,
quantize_config.bits,
disable_exllama
)
)
return cls(
@ -942,39 +1002,35 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
quantize_config,
is_triton_backend=use_triton,
injected_fused_attention=inject_fused_attention,
injected_fused_mlp=inject_fused_mlp and use_triton,
injected_fused_mlp=inject_fused_mlp,
trainable=trainable
)
def warmup_triton(self, enabled: bool = True):
@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
raise NotImplementedError()
@staticmethod
def _fuse_mlp(
model: PreTrainedModel,
trainable: bool = False
) -> None:
raise NotImplementedError()
@staticmethod
def warmup_triton(model: nn.Module, enabled: bool = True) -> None:
if not enabled:
return
if not TRITON_AVAILABLE:
logger.warning(f"triton is not available, skip warmup stage directly.")
logger.warning(f"Triton is not available, skip warmup stage directly.")
return
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
if self.fused_mlp_module_type is not None:
self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
def enable_trainable_mode(self, enabled: bool = True):
if not self.is_triton_backend and enabled:
raise NotImplementedError("For now, trainable mode only supports triton backend.")
for n, m in self.model.named_modules():
if hasattr(m, "trainable"):
setattr(m, "trainable", enabled)
def disable_trainable_mode(self):
self.enable_trainable_mode(enabled=False)
@staticmethod
def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
GeneralQuantLinear.inject_to_model(
model,
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
)
QuantLinear.warmup(model, seqlen=model.seqlen)
def __getattr__(self, item):
try:

View file

@ -1,9 +1,5 @@
from packaging.version import parse as parse_version
from torch import device
from ..utils.import_utils import compare_transformers_version
CPU = device("cpu")
CUDA_0 = device("cuda:0")
@ -20,9 +16,8 @@ SUPPORTED_MODELS = [
"RefinedWeb",
"baichuan",
"internlm",
"llama",
"qwen",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]

View file

@ -1,6 +1,8 @@
from inspect import signature
from typing import Dict, Optional, Union
from xformers.ops.fmha import AttentionOp
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
@ -81,6 +83,7 @@ class AutoGPTQForCausalLM:
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
attn_op: Optional[AttentionOp] = None,
disable_exllama: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
@ -121,6 +124,7 @@ class AutoGPTQForCausalLM:
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
trainable=trainable,
attn_op=attn_op,
disable_exllama=disable_exllama,
**keywords
)
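With attn_op plumbed through AutoGPTQForCausalLM.from_quantized, the new loading options can be exercised end to end. A minimal sketch, assuming a locally saved quantized model; the path and the choice of xformers op are placeholders, not part of this diff:

from auto_gptq import AutoGPTQForCausalLM
import xformers.ops as xop

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",              # placeholder path
    use_safetensors=True,
    inject_fused_attention=True,            # now opt-in: the defaults are flipped to False in this branch
    inject_fused_mlp=True,
    attn_op=(xop.fmha.flash.FwOp(), None),  # optional; mirrors the AttentionOp form used in attention.py
    trainable=False,
)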

View file

@ -1,4 +1,19 @@
from copy import deepcopy
from typing import Optional
import xformers.ops as xop
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from xformers.ops.fmha import AttentionOp
from ._base import *
from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
class BaiChuanFusedAttentionWithRope(FusedAttentionWithRoPE):
pass
class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
@ -12,5 +27,49 @@ class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
]
@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.num_attention_heads
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.model.layers
rope_cache = build_rope_cache(
rotary_dim=model_config.hidden_size // num_heads,
max_position=model_config.max_position_embeddings,
device=model.device,
dtype=model.dtype
)
for layer in layers:
old_attn = layer.self_attn
attn_device = old_attn.W_pack.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear(old_attn.W_pack)
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
new_attn = BaiChuanFusedAttentionWithRope(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=0.0,
resid_dropout=0.0,
scale=scale,
attention_ops=attn_op,
outputs_handler=(lambda x, y, z: (x, z, y)),
training=trainable
)
layer.self_attn = new_attn
del old_attn
empty_cache()
__all__ = ["BaiChuanGPTQForCausalLM"]

View file

@ -1,4 +1,4 @@
from auto_gptq.modeling import BaseGPTQForCausalLM
from ._base import BaseGPTQForCausalLM
class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
@ -14,4 +14,5 @@ class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.c_proj"]
]
__all__ = ["GPTBigCodeGPTQForCausalLM"]

View file

@ -1,5 +1,136 @@
from copy import deepcopy
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import xformers.ops as xop
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
from xformers.ops.fmha import AttentionOp
from ._base import *
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.attention import FusedAttention
from ..nn_modules.fused_modules.mlp import FusedMLP
class GPTJFusedAttention(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
embed_positions: torch.Tensor,
rotary_dim: Optional[int],
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(GPTJFusedAttention, self).__init__(
qkv_proj,
out_proj,
num_query_heads,
num_key_heads,
num_value_heads,
attn_dropout,
resid_dropout,
scale,
attention_ops,
outputs_handler,
training
)
self.embed_positions = embed_positions
self.rotary_dim = rotary_dim
def _get_embed_positions(self, position_ids: torch.Tensor):
return self.embed_positions.repeat(position_ids.shape[0], 1, 1)
def _apply_rotary(
self,
query: torch.Tensor,
key: torch.Tensor,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, seq_len = key.shape[:2]
dtype = query.dtype
query = query.view(bsz, seq_len, self.num_query_heads, -1).to(dtype=torch.float)
key = key.view(bsz, seq_len, self.num_key_heads, -1).to(dtype=torch.float)
embed_positions = self._get_embed_positions(position_ids)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
return query.view(bsz, seq_len, -1).to(dtype=dtype), key.view(bsz, seq_len, -1).to(dtype=dtype)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return xop.LowerTriangularMask()
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
if position_ids is not None:
q, k = self._apply_rotary(q, k, position_ids)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if layer_past is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
layer_past
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
@ -13,7 +144,75 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"]
]
fused_attn_module_type = FusedGPTJAttentionForQuantizedModel
@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.n_head
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.transformer.h
for layer in layers:
old_attn = layer.attn
device = old_attn.q_proj.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear.fuse(
old_attn.q_proj,
old_attn.k_proj,
old_attn.v_proj
)
new_out_proj = FusedGeneralQuantLinear(old_attn.out_proj)
new_attn = GPTJFusedAttention(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
embed_positions=old_attn.embed_positions.to(device),
rotary_dim=old_attn.rotary_dim,
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=model_config.attn_pdrop,
resid_dropout=model_config.resid_pdrop,
scale=scale,
attention_ops=attn_op,
outputs_handler=None,
training=trainable
)
layer.attn = new_attn
del old_attn
empty_cache()
@staticmethod
def _fuse_mlp(
model: PreTrainedModel,
trainable: bool = False
) -> None:
model_config = model.config
act = ACT2FN[model_config.activation_function]
out_dropout = model_config.resid_pdrop
layers = model.transformer.h
for layer in layers:
old_mlp = layer.mlp
new_mlp = FusedMLP(
input_proj=FusedGeneralQuantLinear(old_mlp.fc_in),
out_proj=FusedGeneralQuantLinear(old_mlp.fc_out),
activation=act,
in_dropout=0.0,
out_dropout=out_dropout,
training=trainable,
residual=False
)
layer.mlp = new_mlp
del old_mlp
empty_cache()
__all__ = ["GPTJGPTQForCausalLM"]

View file

@ -1,16 +1,18 @@
from logging import getLogger
from copy import deepcopy
from typing import Optional
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from xformers.ops.fmha import AttentionOp
from ._base import *
from ..utils.import_utils import compare_transformers_version
from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None
logger = getLogger(__name__)
class LlamaFusedAttentionWithRoPE(FusedAttentionWithRoPE):
pass
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
@ -24,8 +26,61 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
]
fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel
@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.num_attention_heads
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.model.layers
rope_cache = build_rope_cache(
rotary_dim=model_config.hidden_size // num_heads,
max_position=model_config.max_position_embeddings,
base=10000,
device=model.device,
dtype=model.dtype
)
for layer in layers:
old_attn = layer.self_attn
attn_device = old_attn.q_proj.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear.fuse(
old_attn.q_proj,
old_attn.k_proj,
old_attn.v_proj
)
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
new_attn = LlamaFusedAttentionWithRoPE(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=0.0,
resid_dropout=0.0,
scale=scale,
attention_ops=attn_op,
outputs_handler=(lambda x, y, z: (x, z, y)),
training=trainable
)
layer.self_attn = new_attn
del old_attn
empty_cache()
# @staticmethod
# def _fuse_mlp(
# model: PreTrainedModel,
# trainable: bool = False
# ) -> None:
# pass
__all__ = ["LlamaGPTQForCausalLM"]

View file

@ -0,0 +1,394 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import xformers.ops as xop
from vllm import pos_encoding_ops as vllm_pos_encoding_ops
from xformers.ops.fmha.attn_bias import LowerTriangularMask, LowerTriangularMaskWithTensorBias
POTENTIAL_KV_CACHE_NAMES = (
"past_key_value",
"layer_past",
"kv_cache"
)
def _try_to_get_kv_cache(**kwargs) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
kv_cache = None
for name in POTENTIAL_KV_CACHE_NAMES:
if name in kwargs:
return kwargs[name]
return kv_cache
def build_rope_cache(
rotary_dim: int,
max_position: int = 2048,
base: int = 10000,
device: torch.device = torch.device("cuda:0"),
dtype: torch.dtype = torch.float16
): # TODO: support multiple scaling strategies
inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim)))
t = torch.arange(max_position, device=device, dtype=dtype)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
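The cache returned above stacks the cos and sin halves along the last dimension, which is the layout that rotary_embedding_neox consumes later in this file. A quick shape check, not part of the diff; the import path assumes this branch is installed together with its new xformers/vllm dependencies:

import torch
from auto_gptq.nn_modules.fused_modules.attention import build_rope_cache

cache = build_rope_cache(rotary_dim=128, max_position=2048,
                         device=torch.device("cpu"), dtype=torch.float32)
assert cache.shape == (2048, 128)   # (max_position, rotary_dim)
cos, sin = cache.chunk(2, dim=-1)   # each (2048, 64): cos half, then sin half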
def build_alibi_slopes(
num_heads: int,
device: torch.device = torch.device("cuda:0"),
dtype: torch.dtype = torch.float16
):
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
slopes = slopes.to(dtype)
return slopes
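For a power-of-two head count, the slopes above reduce to a geometric sequence starting at 2 ** (-8 / num_heads). A small worked example, not part of the diff and importable only with this branch installed:

import torch
from auto_gptq.nn_modules.fused_modules.attention import build_alibi_slopes

slopes = build_alibi_slopes(8, device=torch.device("cpu"), dtype=torch.float32)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])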
def attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_ops: Optional[xop.AttentionOp] = (xop.fmha.flash.FwOp(), None),
attention_bias: Optional[xop.AttentionBias] = None,
p: float = 0.0,
scale: Optional[float] = None
):
if value.shape[2] != query.shape[2]:
# MQA expand
if value.shape[2] == 1:
pass # TODO
# GQA reshape
else:
original_shape = value.shape
pass # TODO
return xop.memory_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attention_bias,
p=p,
scale=scale,
op=attention_ops
)
class FusedAttention(nn.Module):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttention, self).__init__()
self.qkv_proj = qkv_proj
self.out_proj = out_proj
self.num_query_heads = num_query_heads
self.num_key_heads = num_key_heads
self.num_value_heads = num_value_heads
self.attn_dropout = attn_dropout if training else 0.0
self.scale = scale
self.attention_ops = attention_ops
self.outputs_handler = outputs_handler
self.resid_dropout = nn.Dropout(resid_dropout if training else 0.0)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return None
def _attn(
self,
batch_size: int,
seq_len: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_bias: Optional[xop.AttentionBias] = None,
use_cache: bool = False,
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
):
q = q.view(batch_size, seq_len, self.num_query_heads, -1).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_key_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_value_heads, -1).transpose(1, 2)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k = torch.cat((k_cache, k), dim=2)
v = torch.cat((v_cache, v), dim=2)
present = None
if use_cache:
present = (k, v)
attn_out = attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
attention_ops=self.attention_ops,
attention_bias=attention_bias,
p=self.attn_dropout,
scale=self.scale
).view(batch_size, seq_len, -1)
return attn_out, present
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
attn_bias = self._build_attn_bias(hidden_states, attention_mask)
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class FusedAttentionWithRoPE(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
cos_sin_cache: torch.Tensor,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttentionWithRoPE, self).__init__(
qkv_proj=qkv_proj,
out_proj=out_proj,
num_query_heads=num_query_heads,
num_key_heads=num_key_heads,
num_value_heads=num_value_heads,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
scale=scale,
attention_ops=attention_ops,
outputs_handler=outputs_handler,
training=training
)
self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return LowerTriangularMask()
def _apply_rotary_embedding(
self,
query: torch.Tensor,
key: torch.Tensor,
position_ids: Optional[torch.Tensor] = None
):
bsz, seq_len, hidden_size = query.shape
if position_ids is not None:
query = query.view(bsz * seq_len, -1)
key = key.view(bsz * seq_len, -1)
vllm_pos_encoding_ops.rotary_embedding_neox(
position_ids.view(-1).to(query.device),
query,
key,
hidden_size // self.num_query_heads,
self.cos_sin_cache,
)
query = query.view(bsz, seq_len, -1)
key = key.view(bsz, seq_len, -1)
return query, key
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
position_ids = kwargs.get("position_ids", None)
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
q, k = self._apply_rotary_embedding(q, k, position_ids)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class FusedAttentionWithALiBi(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
alibi_slopes: torch.Tensor,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttentionWithALiBi, self).__init__(
qkv_proj=qkv_proj,
out_proj=out_proj,
num_query_heads=num_query_heads,
num_key_heads=num_key_heads,
num_value_heads=num_value_heads,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
scale=scale,
attention_ops=attention_ops,
outputs_handler=outputs_handler,
training=training
)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]: # adopt from vllm
bsz, seq_len = hidden_states.shape[:2]
bias = torch.arange(seq_len)
bias = bias[None, :] - bias[:, None]
bias = bias.to(hidden_states.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
bias = torch.empty(
self.num_query_heads,
padded_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :seq_len, :seq_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
bias = LowerTriangularMaskWithTensorBias(bias.unsqueeze(0).repeat(bsz, 1, 1, 1))
return bias
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
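Model adapters are expected to inherit one of the three fused attention classes above and customize only what differs per architecture, in particular the attn_bias building hook. A minimal sketch of such a subclass (the class name is hypothetical; the import path assumes this branch):

from typing import Optional
import torch
import xformers.ops as xop
from auto_gptq.nn_modules.fused_modules.attention import FusedAttentionWithRoPE

class MyFusedAttentionWithRoPE(FusedAttentionWithRoPE):
    def _build_attn_bias(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Optional[xop.AttentionBias]:
        # keep plain causal masking here; an architecture-specific adapter could
        # instead build an ALiBi or padding-aware bias from attention_mask
        return xop.LowerTriangularMask()

The outputs_handler callable passed at construction time can then reorder the generic (output, present, attn_weights) tuple into whatever the replaced transformers attention module is expected to return, e.g. lambda x, y, z: (x, z, y) as the Llama and Baichuan adapters in this branch do.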

View file

@ -0,0 +1,97 @@
import torch
from ..qlinear import GeneralQuantLinear
from ..qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
from ..qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
try:
from ..qlinear.qlinear_triton import QuantLinear as TritonQuantLinear
except:
TritonQuantLinear = None
try:
from ..qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
except:
ExllamaQuantLinear = None
class FusedGeneralQuantLinear(GeneralQuantLinear):
def __init__(self, quant_linear_module):
super(FusedGeneralQuantLinear, self).__init__(quant_linear_module)
@classmethod
def fuse(
cls,
q_proj,
k_proj=None,
v_proj=None,
):
qweights, qzeros, scales, g_idx, bias = [], [], [], [], []
outfeatures = 0
for module in [q_proj, k_proj, v_proj]:
if module is not None:
qweights.append(module.qweight)
qzeros.append(module.qzeros)
scales.append(module.scales)
g_idx.append(module.g_idx)
bias.append(module.bias)
outfeatures += module.outfeatures
if bias[0] is None:
bias = None
if len(qweights) > 1:
qweights = torch.cat(qweights, dim=1)
qzeros = torch.cat(qzeros, dim=1)
scales = torch.cat(scales, dim=1)
g_idx = torch.cat(g_idx, dim=0)
if bias is not None:
bias = torch.cat(bias, dim=0)
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
outfeatures,
bias is not None
)
qlinear_kwargs = {"trainable": q_proj.trainable}
if isinstance(q_proj, (OldCudaQuantLinear, CudaQuantLinear)):
qlinear_kwargs["kernel_switch_threshold"] = q_proj.kernel_switch_threshold
if isinstance(q_proj, OldCudaQuantLinear):
qlinear_kwargs["use_cuda_fp16"] = q_proj.use_cuda_fp16
QuantLinear = OldCudaQuantLinear
else:
QuantLinear = CudaQuantLinear
elif isinstance(q_proj, TritonQuantLinear):
QuantLinear = TritonQuantLinear
else:
QuantLinear = ExllamaQuantLinear
fused_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
fused_proj.qweight = qweights
fused_proj.qzeros = qzeros
fused_proj.scales = scales
fused_proj.g_idx = g_idx
fused_proj.bias = bias
if isinstance(q_proj, ExllamaQuantLinear):
if not hasattr(q_proj, "_use_act_order"):
raise AttributeError(
"q_proj doesn't have attribute _use_act_order, please execute "
"auto_gptq.modeling._utils.autogptq_post_init function before "
"fuse quant linears."
)
if q_proj._use_act_order:
# TODO: support it. The issue lies maybe in the line:
# int groups = qzeros.size(0);
# in exllama_ext.cpp
raise ValueError(
"Exllama kernel does not support layer fusion with act-order. "
"Please either use inject_fused_attention=False or disable_exllama=True."
)
fused_proj._use_act_order = q_proj._use_act_order
fused_proj.g_idx = None
fused_proj.post_init()
del q_proj, k_proj, v_proj
return cls(fused_proj)

View file

@ -0,0 +1,168 @@
from functools import partial
from typing import Union, Callable, Optional
import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F
def act_dropout(
hidden_states: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(hidden_states)
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
def dropout_res(
hidden_states: torch.Tensor,
residual: torch.Tensor,
dropout: float = 0.0
):
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
return torch.add(hidden_states, residual)
def act_dropout_res(
hidden_states: torch.Tensor,
residual: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(hidden_states)
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
return torch.add(hidden_states, residual)
class NVFusedActDropoutRes(nn.Module):
def __init__(
self,
activation: Optional[Union[Callable, nn.Module]] = None,
dropout: float = 0.0,
residual: bool = False,
is_cuda: bool = False
):
super(NVFusedActDropoutRes, self).__init__()
fn = partial(F.dropout, p=dropout)
if activation is not None and residual:
fn = partial(act_dropout_res, activation=activation, dropout=dropout)
elif activation is not None:
fn = partial(act_dropout, activation=activation, dropout=dropout)
elif residual:
fn = partial(dropout_res, dropout=dropout)
self.fn = fn
if is_cuda:
self.fn = memory_efficient_fusion(self.fn)
self.residual = residual
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
if self.residual:
return self.fn(hidden_states, residual)
else:
return self.fn(hidden_states)
class FusedMLP(nn.Module):
def __init__(
self,
input_proj: nn.Linear,
out_proj: nn.Linear,
activation: Optional[Union[Callable, nn.Module]] = None,
in_dropout: float = 0.0,
out_dropout: float = 0.0,
training: bool = False,
residual: bool = False
):
super(FusedMLP, self).__init__()
if activation is None:
activation = nn.Identity()
is_cuda = input_proj.weight.data.device.type == "cuda"
self.input_proj = input_proj
self.fused_op1 = NVFusedActDropoutRes(
activation=activation,
dropout=in_dropout if training else 0.0,
residual=False,
is_cuda=is_cuda
)
self.out_proj = out_proj
self.fused_op2 = NVFusedActDropoutRes(
activation=None,
dropout=out_dropout if training else 0.0,
residual=residual,
is_cuda=is_cuda
)
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)
def gated_act_dropout(
gate_states: torch.Tensor,
up_states: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(gate_states) * up_states
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
class NVFusedGatedActDropout(nn.Module):
def __init__(
self,
activation: Optional[Union[Callable, nn.Module]] = None,
dropout: float = 0.0,
is_cuda: bool = False
):
super(NVFusedGatedActDropout, self).__init__()
fn = partial(F.dropout, p=dropout)
if activation is not None:
fn = partial(gated_act_dropout, activation=activation, dropout=dropout)
self.fn = fn
if is_cuda:
self.fn = memory_efficient_fusion(self.fn)
def forward(self, gate_states: torch.Tensor, up_states):
if isinstance(self.fn, nn.Dropout):
return self.fn(gate_states * up_states)
return self.fn(gate_states, up_states)
class FusedGatedMLP(nn.Module):
def __init__(
self,
input_proj: nn.Linear,
out_proj: nn.Linear,
activation: Optional[Union[Callable, nn.Module]] = None,
in_dropout: float = 0.0,
out_dropout: float = 0.0,
training: bool = False
):
super(FusedGatedMLP, self).__init__()
if activation is None:
activation = nn.Identity()
self.input_proj = input_proj
self.fused_op = NVFusedGatedActDropout(
activation=activation,
dropout=in_dropout if training else 0.0,
is_cuda=input_proj.weight.data.device.type == "cuda"
)
self.out_proj = out_proj
self.out_dropout = nn.Dropout(out_dropout)
self.intermediate_size = self.input_proj.out_features // 2
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.input_proj(hidden_states)
return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
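FusedGatedMLP assumes the gate and up projections have been packed into a single input_proj whose out_features is twice the intermediate size; forward() splits the result with chunk(2, dim=-1), computes activation(gate) * up, then projects back down. A CPU-only sanity sketch with plain nn.Linear standing in for the fused quant linears (names and sizes are illustrative only; the import path assumes this branch):

import torch
import torch.nn as nn
from auto_gptq.nn_modules.fused_modules.mlp import FusedGatedMLP

hidden_size, intermediate_size = 32, 64
mlp = FusedGatedMLP(
    input_proj=nn.Linear(hidden_size, 2 * intermediate_size, bias=False),
    out_proj=nn.Linear(intermediate_size, hidden_size, bias=False),
    activation=nn.SiLU(),
)
x = torch.randn(1, 8, hidden_size)
assert mlp(x).shape == (1, 8, hidden_size)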

View file

@ -6,7 +6,7 @@ class GeneralQuantLinear(nn.Linear):
super().__init__(
in_features=quant_linear_module.infeatures,
out_features=quant_linear_module.outfeatures,
bias=True
bias=quant_linear_module.bias is not None
)
self.infeatures = quant_linear_module.infeatures
self.outfeatures = quant_linear_module.outfeatures
@ -18,28 +18,47 @@ class GeneralQuantLinear(nn.Linear):
self.weight.data = quant_linear_module.qweight
self.register_buffer('qweight', quant_linear_module.qweight)
if quant_linear_module.bias is not None:
self.bias.data = quant_linear_module.bias
self.qweight.requires_grad = False
self.bias.requires_grad = False
self.register_buffer('qzeros', quant_linear_module.qzeros)
self.register_buffer('scales', quant_linear_module.scales)
self.register_buffer('g_idx', quant_linear_module.g_idx)
# arg of qlinear_cuda and qlinear_cuda_old
if hasattr(quant_linear_module, "wf"):
self.wf = quant_linear_module.wf
# arg of qlinaer_cuda and qlinear_cuda_old
if hasattr(quant_linear_module, "kernel_switch_threshold"):
self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
# arg of qlinaer_cuda and qlinear_cuda_old
if hasattr(quant_linear_module, "autogptq_cuda_available"):
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
# arg of qlinaer_cuda and qlinear_cuda_old
if hasattr(quant_linear_module, "autogptq_cuda"):
self.autogptq_cuda = quant_linear_module.autogptq_cuda
# arg of qlinear_cuda_old
if hasattr(quant_linear_module, "half_indim"):
self.half_indim = quant_linear_module.half_indim
# arg of qlinear_cuda_old
if hasattr(quant_linear_module, "use_cuda_fp16"):
self.use_cuda_fp16 = quant_linear_module.use_cuda_fp16
# args of qlinear_exllama
if hasattr(quant_linear_module, "_use_act_order"):
self._use_act_order = quant_linear_module._use_act_order
# arg of qlinaer_exllama
if hasattr(quant_linear_module, "width"):
self.width = quant_linear_module.width
# arg of qlinear_exllama
if hasattr(quant_linear_module, "q4"):
self.q4 = quant_linear_module.q4
self.trainable = quant_linear_module.trainable
self.forward = quant_linear_module.forward
@classmethod
def inject_to_model(cls, model, target_module_type):
def convert_to_torch_linear(cls, model: nn.Module, target_module_type: "QuantLinear"):
for name, m in model.named_modules():
if not isinstance(m, target_module_type):
continue

View file

@ -36,8 +36,6 @@ class QuantLinear(nn.Module):
global _autogptq_cuda_available
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
_autogptq_cuda_available = False
self.infeatures = infeatures
self.outfeatures = outfeatures
@ -198,7 +196,7 @@ class QuantLinear(nn.Module):
x = x.reshape(-1, x.shape[-1])
if self.autogptq_cuda_available and (
self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
):
) and not self.trainable:
out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
if self.bits == 2:
self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)

View file

@ -36,8 +36,7 @@ class QuantLinear(nn.Module):
global _autogptq_cuda_available
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
_autogptq_cuda_available = False
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
@ -198,7 +197,7 @@ class QuantLinear(nn.Module):
x = x.reshape(-1, x.shape[-1])
if self.autogptq_cuda_available is True and (
self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
):
) and not self.trainable:
out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
if self.use_cuda_fp16:
x = x.half()

View file

@ -354,7 +354,7 @@ def get_gptq_peft_model(
train_mode: bool = False
):
if train_mode and not model.trainable:
model.enable_trainable_mode()
raise TypeError("model is not trainable, please load model with 'trainable=True'")
if train_mode and not peft_config:
raise ValueError("peft_config not specified when in train mode.")
if not train_mode and not model_id:

View file

@ -87,30 +87,29 @@ def load_data(data_path, tokenizer, n_samples, max_new_tokens):
outputs = examples["output"]
prompts = []
texts = []
outs = []
input_ids = []
attention_mask = []
for istr, inp, opt in zip(instructions, inputs, outputs):
if inp:
prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
text = prompt + opt
else:
prompt = f"Instruction:\n{istr}\nOutput:\n"
text = prompt + opt
if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length - max_new_tokens:
continue
tokenized_data = tokenizer(text)
tokenized_data = tokenizer(prompt)
input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
prompts.append(prompt)
texts.append(text)
outs.append(opt)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"prompt": prompts
"prompt": prompts,
"output": outs
}
dataset = Dataset.from_generator(dummy_gen)
@ -236,9 +235,9 @@ def main():
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--use_safetensors", action="store_true")
parser.add_argument("--use_fast_tokenizer", action="store_true")
parser.add_argument("--inject_fused_attention", action="store_true")
parser.add_argument("--inject_fused_mlp", action="store_true")
parser.add_argument("--disable_exllama", action="store_true")
parser.add_argument("--no_inject_fused_attention", action="store_true")
parser.add_argument("--no_inject_fused_mlp", action="store_true")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--per_gpu_max_memory", type=int, default=None)
parser.add_argument("--cpu_max_memory", type=int, default=None)
@ -277,8 +276,8 @@ def main():
use_triton=args.use_triton,
use_safetensors=args.use_safetensors,
use_fast_tokenizer=args.use_fast_tokenizer,
inject_fused_attention=not args.no_inject_fused_attention,
inject_fused_mlp=not args.no_inject_fused_mlp,
inject_fused_attention=args.inject_fused_attention,
inject_fused_mlp=args.inject_fused_mlp,
disable_exllama=args.disable_exllama
)
end = time.time()
@ -289,7 +288,7 @@ def main():
if args.use_triton:
logger.info("warmup triton, this may take a while.")
model.warmup_triton()
model.warmup_triton(model)
logger.info("loading data")
examples = load_data(

View file

@ -37,12 +37,18 @@ if __name__ == "__main__":
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument("--inject_fused_attention", action="store_true", help="Whether to inject fused attention")
parser.add_argument("--inject_fused_mlp", action="store_true", help="Whether to inject fused mlp")
parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
args = parser.parse_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name,
use_fast=args.use_fast_tokenizer,
trust_remote_code=args.trust_remote_code
)
if not tokenizer.pad_token_id:
tokenizer.pad_token_id = tokenizer.eos_token_id
@ -68,8 +74,8 @@ if __name__ == "__main__":
model_basename=args.model_basename,
use_safetensors=args.use_safetensors,
trust_remote_code=args.trust_remote_code,
inject_fused_mlp=False,
inject_fused_attention=False,
inject_fused_mlp=args.inject_fused_mlp,
inject_fused_attention=args.inject_fused_attention,
disable_exllama=args.disable_exllama
)
else:

View file

@ -68,7 +68,10 @@ requirements = [
"datasets",
"numpy",
"rouge",
"torch>=1.13.0",
"torch>=2.0.1",
"functorch",
"xformers>=0.0.20",
"vllm",
"safetensors",
"transformers>=4.31.0",
"peft"