Compare commits

...

36 commits

Author SHA1 Message Date
PanQiWei
d95661b250 improve code format 2023-08-13 16:14:15 +08:00
PanQiWei
0a04d3fb2a explicit set "base" value 2023-08-13 16:14:01 +08:00
PanQiWei
7c2ec905a6 extrac rope logic into a single method for better override in child class 2023-08-13 16:13:44 +08:00
PanQiWei
b1c64d9269 add baichuan model attention fusion logic 2023-08-11 19:12:43 +08:00
PanQiWei
eeee7b344f add trust_remote_code argument in tokenizer init 2023-08-11 19:12:30 +08:00
PanQiWei
8fedbbf82d using transformers gptj rope implementation 2023-08-11 18:26:23 +08:00
PanQiWei
fdb8c4500a extend to support qlinear_exllama's fusion 2023-08-11 14:52:26 +08:00
PanQiWei
43b9a5cd0a fix pass in wrong argument 2023-08-10 15:49:52 +08:00
PanQiWei
efe47aafe5 prevent potential import error 2023-08-10 15:36:54 +08:00
PanQiWei
3d09cf36d7 fix syntax error 2023-08-10 15:36:21 +08:00
潘其威(William)
beab695c5b Merge branch 'main' into xformers_integration 2023-08-10 15:27:11 +08:00
PanQiWei
edc5b72da4 using pytorch backend rope 2023-08-09 10:20:58 +08:00
PanQiWei
26dc6852fe support inherit one of the three fused attention class and customize attn_bias building logic 2023-08-07 18:59:04 +08:00
PanQiWei
d73ed1cfc2 freeze triton version to 2.0.0 2023-08-07 18:55:59 +08:00
PanQiWei
e5f874e5af add fused attention injection logic to llama 2023-08-07 13:45:37 +08:00
PanQiWei
700406e6b6 add 'inject_fused_attention' and 'inject_fused_mlp' action flags 2023-08-06 18:50:28 +08:00
PanQiWei
2092a80b81 keep attn_op as what it is when passed in 2023-08-06 18:38:25 +08:00
PanQiWei
4aea0aef39 update benchmark script 2023-08-06 18:37:23 +08:00
PanQiWei
1f9717af7f change classes default values 2023-08-06 18:24:23 +08:00
PanQiWei
7a70bcf6d8 doing 'memory_efficient_fusion' in __init__ 2023-08-06 17:23:57 +08:00
PanQiWei
57c3e5b7d5 change CL 2023-08-06 16:24:59 +08:00
PanQiWei
01ce32553e remove unnecessary lines 2023-08-06 16:24:44 +08:00
PanQiWei
677409e2fe fix using wrong attribute 2023-08-06 16:23:19 +08:00
PanQiWei
9155ef3038 fix using wrong attribute 2023-08-06 15:37:11 +08:00
PanQiWei
df24da5797 mark fused ops injection as experiment features 2023-08-06 15:05:28 +08:00
PanQiWei
ab6faa6496 implement gptj attention and mlp fused ops injection logic 2023-08-06 14:55:06 +08:00
PanQiWei
f67b512cee add 'training' argument 2023-08-06 14:54:34 +08:00
PanQiWei
bacac399d3 abandon change trainable mode after model is loaded; support specify customized AttentionOp 2023-08-06 14:15:43 +08:00
PanQiWei
c71f5cdf12 add '_fuse_attention' and '_fuse_mlp' abstract static methods 2023-08-06 12:45:08 +08:00
PanQiWei
0fcfddda90 rename 'inject_to_model' to 'convert_to_torch_linear' 2023-08-06 12:09:16 +08:00
PanQiWei
2826729e73 use pytorch normal forward logic when trainable is True 2023-08-06 11:44:29 +08:00
PanQiWei
801610367d Merge branch 'main' into xformers_integration 2023-08-05 18:02:00 +08:00
PanQiWei
7d0909160c add fused MLPs 2023-08-04 20:03:16 +08:00
PanQiWei
8b19122775 add fused attentions 2023-08-04 19:11:43 +08:00
PanQiWei
cd8a674002 add FusedGeneralQuantLinear 2023-08-04 19:10:32 +08:00
PanQiWei
116d8267d7 update requirements 2023-08-04 19:10:05 +08:00
18 changed files with 1197 additions and 145 deletions
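Taken together, these commits replace the old `fused_attn_module_type` / `fused_mlp_module_type` injection hooks with per-model `_fuse_attention()` / `_fuse_mlp()` static methods built on xformers, and surface them through `from_quantized()` via `inject_fused_attention`, `inject_fused_mlp` and the new `attn_op` argument (both injection flags now default to `False`). A minimal sketch of how the new loading path is driven; the checkpoint path is a placeholder and the `attn_op` value is only one possible xformers operator pair, written in the same style as the default used inside the fused attention module:

```python
import xformers.ops as xop
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM

quantized_model_dir = "path/to/quantized-model"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, trust_remote_code=True)
model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir,
    device="cuda:0",
    trust_remote_code=True,
    inject_fused_attention=True,              # defaults to False after this change
    inject_fused_mlp=True,                    # defaults to False after this change
    attn_op=(xop.fmha.cutlass.FwOp(), None),  # optional xformers AttentionOp, forwarded to _fuse_attention
)

inputs = tokenizer("auto_gptq is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```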

View file

@@ -17,11 +17,11 @@ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
from transformers.modeling_utils import no_init_weights
+from xformers.ops.fmha import AttentionOp
from ._const import *
from ._utils import *
from ..nn_modules.qlinear import GeneralQuantLinear
-from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import (
@@ -73,7 +73,7 @@ class BaseQuantizeConfig(PushToHubMixin):
quantize_config_filename = "quantize_config.json"
if os.path.isdir(save_dir): # Local
resolved_config_file = join(save_dir, quantize_config_filename)
else: # Remote
resolved_config_file = cached_file(
save_dir,
quantize_config_filename,
@@ -89,10 +89,9 @@ class BaseQuantizeConfig(PushToHubMixin):
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
with open(resolved_config_file, "r", encoding="utf-8") as f:
return cls(**json.load(f))
def to_dict(self):
return {
"bits": self.bits,
@@ -114,9 +113,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
inside_layer_modules: List[List[str]] = None
lm_head_name: str = "lm_head"
-fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
-fused_mlp_module_type: Optional[FusedBaseMLPModule] = None
def __init__(
self,
model: PreTrainedModel,
@@ -210,14 +206,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")
if use_triton and not TRITON_AVAILABLE:
-logger.warning("triton is not installed, reset use_triton to False")
+logger.warning("Triton is not installed, reset use_triton to False")
use_triton = False
device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu":
-logger.info(f"truly offloading {name} to cpu with hook.")
+logger.info(f"Truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, CUDA_0)
@@ -437,10 +433,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
-def generate(self, **kwargs):
+def generate(self, *args, **kwargs):
"""shortcut for model.generate"""
-with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
-return self.model.generate(**kwargs)
+with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
+return self.model.generate(*args, **kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
@@ -489,9 +485,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
"""
-if (self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)) and save_dir is None:
-raise ValueError("Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading.")
+if (
+self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)
+) and save_dir is None:
+raise ValueError(
+"Quantized model should be saved first, or you can provide save_dir to "
+"make sure model is saved to local disk before uploading."
+)
if save_dir is not None:
logger.info(f"Saving model to {save_dir}")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
@@ -517,16 +518,30 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
repo_type="model",
)
-def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+def save_quantized(
+self,
+save_dir: str,
+use_safetensors: bool = False,
+safetensors_metadata: Optional[Dict[str, str]] = None
+):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
-raise EnvironmentError("can only save quantized model, please execute .quantize first.")
+raise TypeError("Can only save quantized model, please execute .quantize() method first.")
+if self.injected_fused_attention or self.injected_fused_mlp:
+raise TypeError(
+"At least one of attention modules and mlp modules are injected with fused ops, "
+"please disable 'inject_fused_attention' and 'inject_fused_mlp' at model loading stage, "
+"and don't call ._fuse_attention() and ._fuse_mlp() methods before calling this method."
+)
self.model.to(CPU)
-model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+model_base_name = (
+self.quantize_config.model_file_base_name or
+f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+)
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
@@ -546,13 +561,23 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
new_key = str(key)
new_value = str(value)
except Exception as e:
-raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+raise TypeError(
+f"safetensors_metadata: both keys and values must be strings and "
+f"an error occured when trying to convert them: {e}"
+)
if new_key in new_safetensors_metadata:
-logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+logger.warning(
+f"After converting safetensors_metadata keys to strings, the key "
+f"'{new_key}' is duplicated. Ensure that all your metadata keys are "
+f"strings to avoid overwriting."
+)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
-logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+logger.debug(
+f"One or more safetensors_metadata keys or values had to be converted to str(). "
+f"Final safetensors_metadata: {safetensors_metadata}"
+)
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
@@ -576,9 +601,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
self.quantize_config.model_name_or_path = save_dir
self.quantize_config.model_file_base_name = model_base_name
-def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
+def save_pretrained(
+self,
+save_dir: str,
+use_safetensors: bool = False,
+safetensors_metadata: Optional[Dict[str, str]] = None,
+**kwargs
+):
"""alias of save_quantized"""
-logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
+logger.warning("You are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
@classmethod
@@ -672,9 +703,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model.seqlen = model_config[key]
break
else:
-logger.warning("can't get model's sequence length from model config, will set to 4096.")
+logger.warning("Can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
-model.eval()
return cls(model, False, quantize_config)
@@ -688,8 +718,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
torch_dtype: torch.dtype = torch.float16,
-inject_fused_attention: bool = True,
-inject_fused_mlp: bool = True,
+inject_fused_attention: bool = False,
+inject_fused_mlp: bool = False,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
@@ -697,11 +727,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
+attn_op: Optional[AttentionOp] = None,
disable_exllama: bool = False,
**kwargs
):
"""load quantized model from local disk"""
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -725,9 +756,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
if use_triton and not TRITON_AVAILABLE:
-logger.warning("Triton is not installed, reset use_triton to False.")
+logger.warning("Triton is not installed, reset use_triton to False")
use_triton = False
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
logger.warning(
@@ -746,22 +777,33 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"2. You are using pytorch without CUDA support.\n"
"3. CUDA and nvcc are not installed in your device."
)
+if any([inject_fused_attention, inject_fused_mlp]) and trainable:
+logger.warning(
+"Neither fused attention nor fused mlp is tested under trainable mode, "
+"this may cause unexpected behavior or lead to error if you are training "
+"a quantized model with fused ops, please consider disabling 'inject_fused_attention' "
+"and 'inject_fused_mlp'."
+)
# == step1: prepare configs and file names == #
-config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
+config = AutoConfig.from_pretrained(
+model_name_or_path,
+trust_remote_code=trust_remote_code,
+**cached_file_kwargs
+)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
if quantize_config is None:
quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)
if model_basename is None:
if quantize_config.model_file_base_name:
model_basename = quantize_config.model_file_base_name
else:
model_basename = f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g"
quantize_config.model_name_or_path = model_name_or_path
quantize_config.model_file_base_name = model_basename
@@ -786,18 +828,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
if resolved_archive_file is not None:
break
if resolved_archive_file is None: # Could not find a model file to use
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
model_save_name = resolved_archive_file
if not disable_exllama and trainable:
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
disable_exllama = True
elif not use_triton and trainable:
-logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
+logger.warning(
+"QuantLinear with cuda backend not support trainable mode yet, will switch to pytorch backend, "
+"this may cause very slow inference speed, disable trainable if you are not training model."
+)
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
@@ -881,7 +925,11 @@
)
model = simple_dispatch_model(model, device_map)
-# == step4: set seqlen == #
+# == step4: post init model == #
+# Any post-initialization that require device information, for example buffers initialization on device.
+model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
+# == step5: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
@@ -890,50 +938,62 @@
model.seqlen = model_config[key]
break
else:
-logger.warning("can't get model's sequence length from model config, will set to 4096.")
+logger.warning("Can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
-# == step5: (optional) inject optimized module == #
-if inject_fused_attention:
-if cls.fused_attn_module_type is None:
-inject_fused_attention = False
-logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
-else:
-cls.fused_attn_module_type.inject_to_model(
-model,
-use_triton=use_triton,
-group_size=quantize_config.group_size,
-use_cuda_fp16=use_cuda_fp16,
-desc_act=quantize_config.desc_act,
-trainable=trainable,
-bits=quantize_config.bits,
-disable_exllama=disable_exllama,
-)
-if inject_fused_mlp:
-if cls.fused_mlp_module_type is None:
-inject_fused_mlp = False
-logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
-else:
-cls.fused_mlp_module_type.inject_to_model(
-model,
-use_triton=use_triton
-)
-# Any post-initialization that require device information, for example buffers initialization on device.
-model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
-model.eval()
-# == step6: (optional) warmup triton == #
-if use_triton and warmup_triton:
-from ..nn_modules.qlinear.qlinear_triton import QuantLinear
-QuantLinear.warmup(model, seqlen=model.seqlen)
-if inject_fused_mlp and cls.fused_mlp_module_type is not None:
-cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
-# == step7: make model compatible with peft
-cls.make_sure_compatible_with_peft(
-model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
-)
+# == step6: (optional) inject optimized module == #
+if inject_fused_attention:
+try:
+cls._fuse_attention(model, attn_op, trainable)
+except NotImplementedError:
+inject_fused_attention = False
+logger.warning(
+f"{cls.__name__} doesn't support fusing attention yet, will skip inject fused attention."
+)
+except:
+logger.error(
+f"Inject fused attention failed, you can set 'inject_fused_attention' to False to "
+f"bypass the error for now and report it on github."
+)
+raise
+if inject_fused_mlp:
+try:
+cls._fuse_mlp(model, trainable)
+except NotImplementedError:
+inject_fused_mlp = False
+logger.warning(
+f"{cls.__name__} doesn't support fusing mlp yet, will skip inject fused mlp."
+)
+except:
+logger.error(
+f"Inject fused mlp failed, you can set 'inject_fused_mlp' to False to "
+f"bypass the error for now and report it on github."
+)
+raise
+if inject_fused_attention or inject_fused_mlp:
+logger.warning(
+"You are using at least one of 'inject_fused_attention' and 'inject_fused_mlp' "
+"modes, which are now marked as experimental features, feel free to open an issue "
+"or ask any question about those two features on github if you encounter unexpected "
+"behaviors and errors."
+)
+# == step7: (optional) warmup triton == #
+if use_triton and warmup_triton:
+cls.warmup_triton(model)
+# == step8: convert all QuantLinear to sub-class of torch.nn.Linear
+# note if _fuse_attention() and _fuse_mlp() is implemented,
+# all QuantLinear will be converted to sub-class of torch.nn.Linear at injection stage
+GeneralQuantLinear.convert_to_torch_linear(
+model,
+dynamically_import_QuantLinear(
+use_triton,
+quantize_config.desc_act,
+quantize_config.group_size,
+quantize_config.bits,
+disable_exllama
+)
+)
return cls(
@@ -942,39 +1002,35 @@
quantize_config,
is_triton_backend=use_triton,
injected_fused_attention=inject_fused_attention,
-injected_fused_mlp=inject_fused_mlp and use_triton,
+injected_fused_mlp=inject_fused_mlp,
trainable=trainable
)
-def warmup_triton(self, enabled: bool = True):
+@staticmethod
+def _fuse_attention(
+model: PreTrainedModel,
+attn_op: Optional[AttentionOp] = None,
+trainable: bool = False
+) -> None:
+raise NotImplementedError()
+@staticmethod
+def _fuse_mlp(
+model: PreTrainedModel,
+trainable: bool = False
+) -> None:
+raise NotImplementedError()
+@staticmethod
+def warmup_triton(model: nn.Module, enabled: bool = True) -> None:
if not enabled:
return
if not TRITON_AVAILABLE:
-logger.warning(f"triton is not available, skip warmup stage directly.")
+logger.warning(f"Triton is not available, skip warmup stage directly.")
return
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
-QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
+QuantLinear.warmup(model, seqlen=model.seqlen)
-if self.fused_mlp_module_type is not None:
-self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
-def enable_trainable_mode(self, enabled: bool = True):
-if not self.is_triton_backend and enabled:
-raise NotImplementedError("For now, trainable mode only supports triton backend.")
-for n, m in self.model.named_modules():
-if hasattr(m, "trainable"):
-setattr(m, "trainable", enabled)
-def disable_trainable_mode(self):
-self.enable_trainable_mode(enabled=False)
-@staticmethod
-def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
-GeneralQuantLinear.inject_to_model(
-model,
-dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
-)
def __getattr__(self, item):
try:

View file

@@ -1,9 +1,5 @@
-from packaging.version import parse as parse_version
from torch import device
-from ..utils.import_utils import compare_transformers_version
CPU = device("cpu")
CUDA_0 = device("cuda:0")
@@ -20,9 +16,8 @@ SUPPORTED_MODELS = [
"RefinedWeb",
"baichuan",
"internlm",
+"llama",
"qwen",
]
-if compare_transformers_version("v4.28.0", op="ge"):
-SUPPORTED_MODELS.append("llama")
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]

View file

@@ -1,6 +1,8 @@
from inspect import signature
from typing import Dict, Optional, Union
+from xformers.ops.fmha import AttentionOp
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
@@ -81,6 +83,7 @@ class AutoGPTQForCausalLM:
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
+attn_op: Optional[AttentionOp] = None,
disable_exllama: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
@@ -121,6 +124,7 @@ class AutoGPTQForCausalLM:
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
trainable=trainable,
+attn_op=attn_op,
disable_exllama=disable_exllama,
**keywords
)

View file

@@ -1,4 +1,19 @@
from copy import deepcopy
from typing import Optional
import xformers.ops as xop
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from xformers.ops.fmha import AttentionOp
from ._base import *
from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
class BaiChuanFusedAttentionWithRope(FusedAttentionWithRoPE):
pass
class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
@@ -12,5 +27,49 @@ class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
]
@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.num_attention_heads
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.model.layers
rope_cache = build_rope_cache(
rotary_dim=model_config.hidden_size // num_heads,
max_position=model_config.max_position_embeddings,
device=model.device,
dtype=model.dtype
)
for layer in layers:
old_attn = layer.self_attn
attn_device = old_attn.W_pack.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear(old_attn.W_pack)
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
new_attn = BaiChuanFusedAttentionWithRope(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=0.0,
resid_dropout=0.0,
scale=scale,
attention_ops=attn_op,
outputs_handler=(lambda x, y, z: (x, z, y)),
training=trainable
)
layer.self_attn = new_attn
del old_attn
empty_cache()
__all__ = ["BaiChuanGPTQForCausalLM"] __all__ = ["BaiChuanGPTQForCausalLM"]

View file

@@ -1,4 +1,4 @@
-from auto_gptq.modeling import BaseGPTQForCausalLM
+from ._base import BaseGPTQForCausalLM
class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
@@ -14,4 +14,5 @@ class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.c_proj"]
]
__all__ = ["GPTBigCodeGPTQForCausalLM"]

View file

@@ -1,5 +1,136 @@
from copy import deepcopy
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import xformers.ops as xop
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
from xformers.ops.fmha import AttentionOp
from ._base import *
-from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
+from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.attention import FusedAttention
from ..nn_modules.fused_modules.mlp import FusedMLP
class GPTJFusedAttention(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
embed_positions: torch.Tensor,
rotary_dim: Optional[int],
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(GPTJFusedAttention, self).__init__(
qkv_proj,
out_proj,
num_query_heads,
num_key_heads,
num_value_heads,
attn_dropout,
resid_dropout,
scale,
attention_ops,
outputs_handler,
training
)
self.embed_positions = embed_positions
self.rotary_dim = rotary_dim
def _get_embed_positions(self, position_ids: torch.Tensor):
return self.embed_positions.repeat(position_ids.shape[0], 1, 1)
def _apply_rotary(
self,
query: torch.Tensor,
key: torch.Tensor,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, seq_len = key.shape[:2]
dtype = query.dtype
query = query.view(bsz, seq_len, self.num_query_heads, -1).to(dtype=torch.float)
key = key.view(bsz, seq_len, self.num_key_heads, -1).to(dtype=torch.float)
embed_positions = self._get_embed_positions(position_ids)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
return query.view(bsz, seq_len, -1).to(dtype=dtype), key.view(bsz, seq_len, -1).to(dtype=dtype)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return xop.LowerTriangularMask()
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
if position_ids is not None:
q, k = self._apply_rotary(q, k, position_ids)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if layer_past is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
layer_past
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
@@ -13,7 +144,75 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"]
]
-fused_attn_module_type = FusedGPTJAttentionForQuantizedModel
+@staticmethod
def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.n_head
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.transformer.h
for layer in layers:
old_attn = layer.attn
device = old_attn.q_proj.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear.fuse(
old_attn.q_proj,
old_attn.k_proj,
old_attn.v_proj
)
new_out_proj = FusedGeneralQuantLinear(old_attn.out_proj)
new_attn = GPTJFusedAttention(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
embed_positions=old_attn.embed_positions.to(device),
rotary_dim=old_attn.rotary_dim,
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=model_config.attn_pdrop,
resid_dropout=model_config.resid_pdrop,
scale=scale,
attention_ops=attn_op,
outputs_handler=None,
training=trainable
)
layer.attn = new_attn
del old_attn
empty_cache()
@staticmethod
def _fuse_mlp(
model: PreTrainedModel,
trainable: bool = False
) -> None:
model_config = model.config
act = ACT2FN[model_config.activation_function]
out_dropout = model_config.resid_pdrop
layers = model.transformer.h
for layer in layers:
old_mlp = layer.mlp
new_mlp = FusedMLP(
input_proj=FusedGeneralQuantLinear(old_mlp.fc_in),
out_proj=FusedGeneralQuantLinear(old_mlp.fc_out),
activation=act,
in_dropout=0.0,
out_dropout=out_dropout,
training=trainable,
residual=False
)
layer.mlp = new_mlp
del old_mlp
empty_cache()
__all__ = ["GPTJGPTQForCausalLM"] __all__ = ["GPTJGPTQForCausalLM"]

View file

@@ -1,16 +1,18 @@
-from logging import getLogger
+from copy import deepcopy
from typing import Optional
from torch.cuda import empty_cache
from transformers import PreTrainedModel
from xformers.ops.fmha import AttentionOp
from ._base import *
-from ..utils.import_utils import compare_transformers_version
+from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
-if compare_transformers_version("v4.28.0", op="ge"):
-from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
-from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
-else:
-FusedLlamaAttentionForQuantizedModel = None
-FusedLlamaMLPForQuantizedModel = None
-logger = getLogger(__name__)
+class LlamaFusedAttentionWithRoPE(FusedAttentionWithRoPE):
pass
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
@@ -24,8 +26,61 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
]
-fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
-fused_mlp_module_type = FusedLlamaMLPForQuantizedModel
+@staticmethod
+def _fuse_attention(
model: PreTrainedModel,
attn_op: Optional[AttentionOp] = None,
trainable: bool = False
) -> None:
model_config = model.config
num_heads = model_config.num_attention_heads
scale = (model_config.hidden_size // num_heads) ** -0.5
layers = model.model.layers
rope_cache = build_rope_cache(
rotary_dim=model_config.hidden_size // num_heads,
max_position=model_config.max_position_embeddings,
base=10000,
device=model.device,
dtype=model.dtype
)
for layer in layers:
old_attn = layer.self_attn
attn_device = old_attn.q_proj.qweight.data.device
new_qkv_proj = FusedGeneralQuantLinear.fuse(
old_attn.q_proj,
old_attn.k_proj,
old_attn.v_proj
)
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
new_attn = LlamaFusedAttentionWithRoPE(
qkv_proj=new_qkv_proj,
out_proj=new_out_proj,
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
num_query_heads=num_heads,
num_key_heads=num_heads,
num_value_heads=num_heads,
attn_dropout=0.0,
resid_dropout=0.0,
scale=scale,
attention_ops=attn_op,
outputs_handler=(lambda x, y, z: (x, z, y)),
training=trainable
)
layer.self_attn = new_attn
del old_attn
empty_cache()
# @staticmethod
# def _fuse_mlp(
# model: PreTrainedModel,
# trainable: bool = False
# ) -> None:
# pass
__all__ = ["LlamaGPTQForCausalLM"] __all__ = ["LlamaGPTQForCausalLM"]

View file

@@ -0,0 +1,394 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import xformers.ops as xop
from vllm import pos_encoding_ops as vllm_pos_encoding_ops
from xformers.ops.fmha.attn_bias import LowerTriangularMask, LowerTriangularMaskWithTensorBias
POTENTIAL_KV_CACHE_NAMES = (
"past_key_value",
"layer_past",
"kv_cache"
)
def _try_to_get_kv_cache(**kwargs) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
kv_cache = None
for name in POTENTIAL_KV_CACHE_NAMES:
if name in kwargs:
return kwargs[name]
return kv_cache
def build_rope_cache(
rotary_dim: int,
max_position: int = 2048,
base: int = 10000,
device: torch.device = torch.device("cuda:0"),
dtype: torch.dtype = torch.float16
): # TODO: support multiple scaling strategies
inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim)))
t = torch.arange(max_position, device=device, dtype=dtype)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def build_alibi_slopes(
num_heads: int,
device: torch.device = torch.device("cuda:0"),
dtype: torch.dtype = torch.float16
):
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
slopes = slopes.to(dtype)
return slopes
def attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_ops: Optional[xop.AttentionOp] = (xop.fmha.flash.FwOp(), None),
attention_bias: Optional[xop.AttentionBias] = None,
p: float = 0.0,
scale: Optional[float] = None
):
if value.shape[2] != query.shape[2]:
# MQA expand
if value.shape[2] == 1:
pass # TODO
# GQA reshape
else:
original_shape = value.shape
pass # TODO
return xop.memory_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attention_bias,
p=p,
scale=scale,
op=attention_ops
)
class FusedAttention(nn.Module):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttention, self).__init__()
self.qkv_proj = qkv_proj
self.out_proj = out_proj
self.num_query_heads = num_query_heads
self.num_key_heads = num_key_heads
self.num_value_heads = num_value_heads
self.attn_dropout = attn_dropout if training else 0.0
self.scale = scale
self.attention_ops = attention_ops
self.outputs_handler = outputs_handler
self.resid_dropout = nn.Dropout(resid_dropout if training else 0.0)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return None
def _attn(
self,
batch_size: int,
seq_len: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_bias: Optional[xop.AttentionBias] = None,
use_cache: bool = False,
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
):
q = q.view(batch_size, seq_len, self.num_query_heads, -1).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_key_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_value_heads, -1).transpose(1, 2)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k = torch.cat((k_cache, k), dim=2)
v = torch.cat((v_cache, v), dim=2)
present = None
if use_cache:
present = (k, v)
attn_out = attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
attention_ops=self.attention_ops,
attention_bias=attention_bias,
p=self.attn_dropout,
scale=self.scale
).view(batch_size, seq_len, -1)
return attn_out, present
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
attn_bias = self._build_attn_bias(hidden_states, attention_mask)
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class FusedAttentionWithRoPE(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
cos_sin_cache: torch.Tensor,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttentionWithRoPE, self).__init__(
qkv_proj=qkv_proj,
out_proj=out_proj,
num_query_heads=num_query_heads,
num_key_heads=num_key_heads,
num_value_heads=num_value_heads,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
scale=scale,
attention_ops=attention_ops,
outputs_handler=outputs_handler,
training=training
)
self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]:
return LowerTriangularMask()
def _apply_rotary_embedding(
self,
query: torch.Tensor,
key: torch.Tensor,
position_ids: Optional[torch.Tensor] = None
):
bsz, seq_len, hidden_size = query.shape
if position_ids is not None:
query = query.view(bsz * seq_len, -1)
key = key.view(bsz * seq_len, -1)
vllm_pos_encoding_ops.rotary_embedding_neox(
position_ids.view(-1).to(query.device),
query,
key,
hidden_size // self.num_query_heads,
self.cos_sin_cache,
)
query = query.view(bsz, seq_len, -1)
key = key.view(bsz, seq_len, -1)
return query, key
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
position_ids = kwargs.get("position_ids", None)
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
q, k = self._apply_rotary_embedding(q, k, position_ids)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
class FusedAttentionWithALiBi(FusedAttention):
def __init__(
self,
qkv_proj: nn.Linear,
out_proj: nn.Linear,
alibi_slopes: torch.Tensor,
num_query_heads: int,
num_key_heads: int,
num_value_heads: int,
attn_dropout: float = 0.0,
resid_dropout: float = 0.0,
scale: Optional[float] = None,
attention_ops: Optional[xop.AttentionOp] = None,
outputs_handler: Optional[Callable] = None,
training: bool = False,
):
super(FusedAttentionWithALiBi, self).__init__(
qkv_proj=qkv_proj,
out_proj=out_proj,
num_query_heads=num_query_heads,
num_key_heads=num_key_heads,
num_value_heads=num_value_heads,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
scale=scale,
attention_ops=attention_ops,
outputs_handler=outputs_handler,
training=training
)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
def _build_attn_bias(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Optional[xop.AttentionBias]: # adopt from vllm
bsz, seq_len = hidden_states.shape[:2]
bias = torch.arange(seq_len)
bias = bias[None, :] - bias[:, None]
bias = bias.to(hidden_states.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
bias = torch.empty(
self.num_query_heads,
padded_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :seq_len, :seq_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
bias = LowerTriangularMaskWithTensorBias(bias.unsqueeze(0).repeat(bsz, 1, 1, 1))
return bias
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
bsz, seq_len = hidden_states.shape[:2]
use_cache = kwargs.get("use_cache", False)
kv_cache = _try_to_get_kv_cache(**kwargs)
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
attn_out, present = self._attn(
bsz,
seq_len,
q,
k,
v,
attn_bias,
use_cache,
kv_cache
)
out = self.out_proj(attn_out)
out = self.resid_dropout(out)
outputs = (out, present, None)
if self.outputs_handler:
outputs = self.outputs_handler(*outputs)
return outputs
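
The file above centralizes the xformers-based attention path that the model-specific classes subclass: plain causal attention, RoPE applied through vLLM's `rotary_embedding_neox` kernel, and ALiBi via a tensor bias. The two standalone helpers can be exercised on their own; a rough sketch, assuming a CUDA device with xformers installed and that this module lands at `auto_gptq.nn_modules.fused_modules.attention` as the imports in the model files above indicate:

```python
import torch
import xformers.ops as xop

from auto_gptq.nn_modules.fused_modules.attention import attention, build_rope_cache

# cos/sin halves are packed along the last dim: shape (max_position, rotary_dim)
rope_cache = build_rope_cache(rotary_dim=128, max_position=2048)
print(rope_cache.shape)  # torch.Size([2048, 128])

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# the wrapper forwards to xop.memory_efficient_attention; attention_ops=None lets xformers pick the kernel
out = attention(q, k, v, attention_ops=None, attention_bias=xop.LowerTriangularMask())
print(out.shape)  # torch.Size([2, 16, 8, 64])
```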

View file

@@ -0,0 +1,97 @@
import torch
from ..qlinear import GeneralQuantLinear
from ..qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
from ..qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
try:
from ..qlinear.qlinear_triton import QuantLinear as TritonQuantLinear
except:
TritonQuantLinear = None
try:
from ..qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
except:
ExllamaQuantLinear = None
class FusedGeneralQuantLinear(GeneralQuantLinear):
def __init__(self, quant_linear_module):
super(FusedGeneralQuantLinear, self).__init__(quant_linear_module)
@classmethod
def fuse(
cls,
q_proj,
k_proj=None,
v_proj=None,
):
qweights, qzeros, scales, g_idx, bias = [], [], [], [], []
outfeatures = 0
for module in [q_proj, k_proj, v_proj]:
if module is not None:
qweights.append(module.qweight)
qzeros.append(module.qzeros)
scales.append(module.scales)
g_idx.append(module.g_idx)
bias.append(module.bias)
outfeatures += module.outfeatures
if bias[0] is None:
bias = None
if len(qweights) > 1:
qweights = torch.cat(qweights, dim=1)
qzeros = torch.cat(qzeros, dim=1)
scales = torch.cat(scales, dim=1)
g_idx = torch.cat(g_idx, dim=0)
if bias is not None:
bias = torch.cat(bias, dim=0)
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
outfeatures,
bias is not None
)
qlinear_kwargs = {"trainable": q_proj.trainable}
if isinstance(q_proj, (OldCudaQuantLinear, CudaQuantLinear)):
qlinear_kwargs["kernel_switch_threshold"] = q_proj.kernel_switch_threshold
if isinstance(q_proj, OldCudaQuantLinear):
qlinear_kwargs["use_cuda_fp16"] = q_proj.use_cuda_fp16
QuantLinear = OldCudaQuantLinear
else:
QuantLinear = CudaQuantLinear
elif isinstance(q_proj, TritonQuantLinear):
QuantLinear = TritonQuantLinear
else:
QuantLinear = ExllamaQuantLinear
fused_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
fused_proj.qweight = qweights
fused_proj.qzeros = qzeros
fused_proj.scales = scales
fused_proj.g_idx = g_idx
fused_proj.bias = bias
if isinstance(q_proj, ExllamaQuantLinear):
if not hasattr(q_proj, "_use_act_order"):
raise AttributeError(
"q_proj doesn't have attribute _use_act_order, please execute "
"auto_gptq.modeling._utils.autogptq_post_init function before "
"fuse quant linears."
)
if q_proj._use_act_order:
# TODO: support it. The issue lies maybe in the line:
# int groups = qzeros.size(0);
# in exllama_ext.cpp
raise ValueError(
"Exllama kernel does not support layer fusion with act-order. "
"Please either use inject_fused_attention=False or disable_exllama=True."
)
fused_proj._use_act_order = q_proj._use_act_order
fused_proj.g_idx = None
fused_proj.post_init()
del q_proj, k_proj, v_proj
return cls(fused_proj)
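
`FusedGeneralQuantLinear.fuse()` is what lets the `_fuse_attention()` implementations above collapse separate q/k/v QuantLinear projections into a single wide projection before handing them to a fused attention module. A self-contained sketch of the packing behaviour, using freshly constructed (zero-initialized) CUDA-old QuantLinear layers purely to show the resulting shapes; real callers pass the already-quantized projections, as in `GPTJGPTQForCausalLM._fuse_attention()`:

```python
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
from auto_gptq.nn_modules.fused_modules.linear import FusedGeneralQuantLinear

hidden_size = 4096

def make_proj():
    # 4-bit, group_size 128, square projection; only used here to illustrate fuse()
    return OldCudaQuantLinear(4, 128, hidden_size, hidden_size, bias=True)

q_proj, k_proj, v_proj = make_proj(), make_proj(), make_proj()

qkv_proj = FusedGeneralQuantLinear.fuse(q_proj, k_proj, v_proj)
print(qkv_proj.infeatures, qkv_proj.outfeatures)  # 4096 12288
print(qkv_proj.qweight.shape)                     # qweights concatenated along dim=1
```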

View file

@@ -0,0 +1,168 @@
from functools import partial
from typing import Union, Callable, Optional
import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F
def act_dropout(
hidden_states: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(hidden_states)
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
def dropout_res(
hidden_states: torch.Tensor,
residual: torch.Tensor,
dropout: float = 0.0
):
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
return torch.add(hidden_states, residual)
def act_dropout_res(
hidden_states: torch.Tensor,
residual: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(hidden_states)
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
return torch.add(hidden_states, residual)
class NVFusedActDropoutRes(nn.Module):
def __init__(
self,
activation: Optional[Union[Callable, nn.Module]] = None,
dropout: float = 0.0,
residual: bool = False,
is_cuda: bool = False
):
super(NVFusedActDropoutRes, self).__init__()
fn = partial(F.dropout, p=dropout)
if activation is not None and residual:
fn = partial(act_dropout_res, activation=activation, dropout=dropout)
elif activation is not None:
fn = partial(act_dropout, activation=activation, dropout=dropout)
elif residual:
fn = partial(dropout_res, dropout=dropout)
self.fn = fn
if is_cuda:
self.fn = memory_efficient_fusion(self.fn)
self.residual = residual
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
if self.residual:
return self.fn(hidden_states, residual)
else:
return self.fn(hidden_states)
class FusedMLP(nn.Module):
def __init__(
self,
input_proj: nn.Linear,
out_proj: nn.Linear,
activation: Optional[Union[Callable, nn.Module]] = None,
in_dropout: float = 0.0,
out_dropout: float = 0.0,
training: bool = False,
residual: bool = False
):
super(FusedMLP, self).__init__()
if activation is None:
activation = nn.Identity()
is_cuda = input_proj.weight.data.device.type == "cuda"
self.input_proj = input_proj
self.fused_op1 = NVFusedActDropoutRes(
activation=activation,
dropout=in_dropout if training else 0.0,
residual=False,
is_cuda=is_cuda
)
self.out_proj = out_proj
self.fused_op2 = NVFusedActDropoutRes(
activation=None,
dropout=out_dropout if training else 0.0,
residual=residual,
is_cuda=is_cuda
)
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)
def gated_act_dropout(
gate_states: torch.Tensor,
up_states: torch.Tensor,
activation: Union[Callable, nn.Module],
dropout: float = 0.0
):
hidden_states = activation(gate_states) * up_states
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
class NVFusedGatedActDropout(nn.Module):
def __init__(
self,
activation: Optional[Union[Callable, nn.Module]] = None,
dropout: float = 0.0,
is_cuda: bool = False
):
super(NVFusedGatedActDropout, self).__init__()
fn = partial(F.dropout, p=dropout)
if activation is not None:
fn = partial(gated_act_dropout, activation=activation, dropout=dropout)
self.fn = fn
if is_cuda:
self.fn = memory_efficient_fusion(self.fn)
def forward(self, gate_states: torch.Tensor, up_states):
if isinstance(self.fn, nn.Dropout):
return self.fn(gate_states * up_states)
return self.fn(gate_states, up_states)
class FusedGatedMLP(nn.Module):
def __init__(
self,
input_proj: nn.Linear,
out_proj: nn.Linear,
activation: Optional[Union[Callable, nn.Module]] = None,
in_dropout: float = 0.0,
out_dropout: float = 0.0,
training: bool = False
):
super(FusedGatedMLP, self).__init__()
if activation is None:
activation = nn.Identity()
self.input_proj = input_proj
self.fused_op = NVFusedGatedActDropout(
activation=activation,
dropout=in_dropout if training else 0.0,
is_cuda=input_proj.weight.data.device.type == "cuda"
)
self.out_proj = out_proj
self.out_dropout = nn.Dropout(out_dropout)
self.intermediate_size = self.input_proj.out_features // 2
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.input_proj(hidden_states)
return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
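
`FusedGatedMLP` expects `input_proj` to pack the gate and up projections along the output dimension (its forward pass chunks the projection output in two), which is why `intermediate_size` is `input_proj.out_features // 2`. The convention can be checked with plain `nn.Linear` layers on CPU; the sizes below are arbitrary and `nn.SiLU` simply mirrors the LLaMA-style gated activation:

```python
import torch
import torch.nn as nn

from auto_gptq.nn_modules.fused_modules.mlp import FusedGatedMLP

hidden_size, intermediate_size = 64, 256

mlp = FusedGatedMLP(
    input_proj=nn.Linear(hidden_size, 2 * intermediate_size, bias=False),  # gate and up packed side by side
    out_proj=nn.Linear(intermediate_size, hidden_size, bias=False),
    activation=nn.SiLU(),
)

x = torch.randn(2, 8, hidden_size)
print(mlp(x).shape)           # torch.Size([2, 8, 64])
print(mlp.intermediate_size)  # 256
```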

View file

@@ -6,7 +6,7 @@ class GeneralQuantLinear(nn.Linear):
         super().__init__(
             in_features=quant_linear_module.infeatures,
             out_features=quant_linear_module.outfeatures,
-            bias=True
+            bias=quant_linear_module.bias is not None
         )
         self.infeatures = quant_linear_module.infeatures
         self.outfeatures = quant_linear_module.outfeatures
@@ -18,28 +18,47 @@ class GeneralQuantLinear(nn.Linear):
         self.weight.data = quant_linear_module.qweight
         self.register_buffer('qweight', quant_linear_module.qweight)
-        self.bias.data = quant_linear_module.bias
+        if quant_linear_module.bias is not None:
+            self.bias.data = quant_linear_module.bias

-        self.qweight.requires_grad = False
-        self.bias.requires_grad = False

         self.register_buffer('qzeros', quant_linear_module.qzeros)
         self.register_buffer('scales', quant_linear_module.scales)
         self.register_buffer('g_idx', quant_linear_module.g_idx)

+        # arg of qlinear_cuda and qlinear_cuda_old
         if hasattr(quant_linear_module, "wf"):
             self.wf = quant_linear_module.wf
+        # arg of qlinear_cuda and qlinear_cuda_old
         if hasattr(quant_linear_module, "kernel_switch_threshold"):
             self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
+        # arg of qlinear_cuda and qlinear_cuda_old
         if hasattr(quant_linear_module, "autogptq_cuda_available"):
             self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
+        # arg of qlinear_cuda and qlinear_cuda_old
+        if hasattr(quant_linear_module, "autogptq_cuda"):
+            self.autogptq_cuda = quant_linear_module.autogptq_cuda
+        # arg of qlinear_cuda_old
+        if hasattr(quant_linear_module, "half_indim"):
+            self.half_indim = quant_linear_module.half_indim
+        # arg of qlinear_cuda_old
+        if hasattr(quant_linear_module, "use_cuda_fp16"):
+            self.use_cuda_fp16 = quant_linear_module.use_cuda_fp16
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "_use_act_order"):
+            self._use_act_order = quant_linear_module._use_act_order
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "width"):
+            self.width = quant_linear_module.width
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "q4"):
+            self.q4 = quant_linear_module.q4

         self.trainable = quant_linear_module.trainable

         self.forward = quant_linear_module.forward

     @classmethod
-    def inject_to_model(cls, model, target_module_type):
+    def convert_to_torch_linear(cls, model: nn.Module, target_module_type: "QuantLinear"):
         for name, m in model.named_modules():
             if not isinstance(m, target_module_type):
                 continue
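A hedged sketch of how the renamed classmethod might be called (the import paths below are assumptions, not shown in this diff): it walks a loaded model and exposes every matching quantized linear through the regular torch nn.Linear interface that the fused modules expect.

import torch.nn as nn
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear  # assumed kernel backend
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear            # assumed location of the class above

def wrap_quant_linears(model: nn.Module) -> nn.Module:
    # wraps every QuantLinear in place so it behaves like a plain nn.Linear
    GeneralQuantLinear.convert_to_torch_linear(model, QuantLinear)
    return model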

View file

@@ -36,8 +36,6 @@ class QuantLinear(nn.Module):
         global _autogptq_cuda_available
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-        if trainable:
-            _autogptq_cuda_available = False

         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -198,7 +196,7 @@ class QuantLinear(nn.Module):
         x = x.reshape(-1, x.shape[-1])
         if self.autogptq_cuda_available and (
             self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
-        ):
+        ) and not self.trainable:
             out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
             if self.bits == 2:
                 self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)

View file

@@ -36,8 +36,7 @@ class QuantLinear(nn.Module):
         global _autogptq_cuda_available
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-        if trainable:
-            _autogptq_cuda_available = False

         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
@@ -198,7 +197,7 @@ class QuantLinear(nn.Module):
         x = x.reshape(-1, x.shape[-1])
         if self.autogptq_cuda_available is True and (
             self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
-        ):
+        ) and not self.trainable:
             out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
             if self.use_cuda_fp16:
                 x = x.half()
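In both CUDA kernels the trainability check moves from construction time (forcing _autogptq_cuda_available to False) to dispatch time, so a trainable layer simply falls back to the pure-PyTorch forward while inference still hits the custom kernel. A simplified illustration of that control flow (this is not the repository's QuantLinear, just the pattern it now follows):

import torch

class DispatchSketch(torch.nn.Module):
    # toy stand-in for QuantLinear's forward dispatch after this change
    def __init__(self, cuda_kernel_available: bool, trainable: bool):
        super().__init__()
        self.cuda_kernel_available = cuda_kernel_available
        self.trainable = trainable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.cuda_kernel_available and not self.trainable:
            return self._cuda_path(x)    # fast kernel path, no autograd support
        return self._torch_path(x)       # differentiable fallback

    def _cuda_path(self, x):             # placeholder for autogptq_cuda.vecquantNmatmul
        return x

    def _torch_path(self, x):            # placeholder for the dequantize-then-matmul path
        return x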

View file

@@ -354,7 +354,7 @@ def get_gptq_peft_model(
     train_mode: bool = False
 ):
     if train_mode and not model.trainable:
-        model.enable_trainable_mode()
+        raise TypeError("model is not trainable, please load model with 'trainable=True'")
     if train_mode and not peft_config:
         raise ValueError("peft_config not specified when in train mode.")
     if not train_mode and not model_id:
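get_gptq_peft_model no longer flips the model into trainable mode itself, so trainability has to be decided when the model is loaded. A sketch of the new contract (the model path and LoRA hyperparameters are placeholders, and the import locations are assumptions):

from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.peft_utils import GPTQLoraConfig, get_gptq_peft_model

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",   # placeholder
    trainable=True,              # must be set here; it can no longer be enabled afterwards
)
peft_model = get_gptq_peft_model(
    model,
    peft_config=GPTQLoraConfig(r=16, lora_alpha=32, lora_dropout=0.05),
    train_mode=True,
)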

View file

@@ -87,30 +87,29 @@ def load_data(data_path, tokenizer, n_samples, max_new_tokens):
         outputs = examples["output"]

         prompts = []
-        texts = []
+        outs = []
         input_ids = []
         attention_mask = []
         for istr, inp, opt in zip(instructions, inputs, outputs):
             if inp:
                 prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
-                text = prompt + opt
             else:
                 prompt = f"Instruction:\n{istr}\nOutput:\n"
-                text = prompt + opt

             if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length - max_new_tokens:
                 continue

-            tokenized_data = tokenizer(text)
+            tokenized_data = tokenizer(prompt)

             input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
             attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
             prompts.append(prompt)
-            texts.append(text)
+            outs.append(opt)

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "prompt": prompts
+            "prompt": prompts,
+            "output": outs
         }

     dataset = Dataset.from_generator(dummy_gen)
@@ -236,9 +235,9 @@ def main():
     parser.add_argument("--use_triton", action="store_true")
     parser.add_argument("--use_safetensors", action="store_true")
     parser.add_argument("--use_fast_tokenizer", action="store_true")
+    parser.add_argument("--inject_fused_attention", action="store_true")
+    parser.add_argument("--inject_fused_mlp", action="store_true")
     parser.add_argument("--disable_exllama", action="store_true")
-    parser.add_argument("--no_inject_fused_attention", action="store_true")
-    parser.add_argument("--no_inject_fused_mlp", action="store_true")
     parser.add_argument("--num_samples", type=int, default=10)
     parser.add_argument("--per_gpu_max_memory", type=int, default=None)
     parser.add_argument("--cpu_max_memory", type=int, default=None)
@@ -277,8 +276,8 @@ def main():
         use_triton=args.use_triton,
         use_safetensors=args.use_safetensors,
         use_fast_tokenizer=args.use_fast_tokenizer,
-        inject_fused_attention=not args.no_inject_fused_attention,
-        inject_fused_mlp=not args.no_inject_fused_mlp,
+        inject_fused_attention=args.inject_fused_attention,
+        inject_fused_mlp=args.inject_fused_mlp,
         disable_exllama=args.disable_exllama
     )
     end = time.time()
@@ -289,7 +288,7 @@ def main():
     if args.use_triton:
         logger.info("warmup triton, this may take a while.")
-        model.warmup_triton()
+        model.warmup_triton(model)

     logger.info("loading data")
     examples = load_data(
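The net effect of the flag change is that fused-op injection becomes opt-in instead of opt-out. A small self-contained sketch of the resulting argparse behaviour (the parser is rebuilt here only for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inject_fused_attention", action="store_true")
parser.add_argument("--inject_fused_mlp", action="store_true")

args = parser.parse_args([])                             # no flags passed: both stay disabled
assert (args.inject_fused_attention, args.inject_fused_mlp) == (False, False)

args = parser.parse_args(["--inject_fused_attention"])   # enable only fused attention
assert args.inject_fused_attention and not args.inject_fused_mlp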

View file

@ -37,12 +37,18 @@ if __name__ == "__main__":
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file") parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code") parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument("--inject_fused_attention", action="store_true", help="Whether to inject fused attention")
parser.add_argument("--inject_fused_mlp", action="store_true", help="Whether to inject fused mlp")
parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel") parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
args = parser.parse_args() args = parser.parse_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer) tokenizer = AutoTokenizer.from_pretrained(
args.model_name,
use_fast=args.use_fast_tokenizer,
trust_remote_code=args.trust_remote_code
)
if not tokenizer.pad_token_id: if not tokenizer.pad_token_id:
tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.pad_token_id = tokenizer.eos_token_id
@ -68,8 +74,8 @@ if __name__ == "__main__":
model_basename=args.model_basename, model_basename=args.model_basename,
use_safetensors=args.use_safetensors, use_safetensors=args.use_safetensors,
trust_remote_code=args.trust_remote_code, trust_remote_code=args.trust_remote_code,
inject_fused_mlp=False, inject_fused_mlp=args.inject_fused_mlp,
inject_fused_attention=False, inject_fused_attention=args.inject_fused_attention,
disable_exllama=args.disable_exllama disable_exllama=args.disable_exllama
) )
else: else:

View file

@@ -68,7 +68,10 @@ requirements = [
     "datasets",
     "numpy",
     "rouge",
-    "torch>=1.13.0",
+    "torch>=2.0.1",
+    "functorch",
+    "xformers>=0.0.20",
+    "vllm",
     "safetensors",
     "transformers>=4.31.0",
     "peft"