Compare commits: main...xformers_i (36 commits)

Commit SHA1 list:
d95661b250, 0a04d3fb2a, 7c2ec905a6, b1c64d9269, eeee7b344f, 8fedbbf82d,
fdb8c4500a, 43b9a5cd0a, efe47aafe5, 3d09cf36d7, beab695c5b, edc5b72da4,
26dc6852fe, d73ed1cfc2, e5f874e5af, 700406e6b6, 2092a80b81, 4aea0aef39,
1f9717af7f, 7a70bcf6d8, 57c3e5b7d5, 01ce32553e, 677409e2fe, 9155ef3038,
df24da5797, ab6faa6496, f67b512cee, bacac399d3, c71f5cdf12, 0fcfddda90,
2826729e73, 801610367d, 7d0909160c, 8b19122775, cd8a674002, 116d8267d7
18 changed files with 1197 additions and 145 deletions
|
@@ -17,11 +17,11 @@ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
|
||||
from transformers.utils.generic import ContextManagers
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from xformers.ops.fmha import AttentionOp
|
||||
|
||||
from ._const import *
|
||||
from ._utils import *
|
||||
from ..nn_modules.qlinear import GeneralQuantLinear
|
||||
from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
|
||||
from ..quantization import GPTQ
|
||||
from ..utils.data_utils import collate_data
|
||||
from ..utils.import_utils import (
|
||||
|
@@ -73,7 +73,7 @@ class BaseQuantizeConfig(PushToHubMixin):
quantize_config_filename = "quantize_config.json"
|
||||
if os.path.isdir(save_dir): # Local
|
||||
resolved_config_file = join(save_dir, quantize_config_filename)
|
||||
else: # Remote
|
||||
else: # Remote
|
||||
resolved_config_file = cached_file(
|
||||
save_dir,
|
||||
quantize_config_filename,
|
||||
|
@@ -89,7 +89,6 @@ class BaseQuantizeConfig(PushToHubMixin):
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
with open(resolved_config_file, "r", encoding="utf-8") as f:
|
||||
return cls(**json.load(f))
|
||||
|
||||
|
@@ -114,9 +113,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
inside_layer_modules: List[List[str]] = None
|
||||
lm_head_name: str = "lm_head"
|
||||
|
||||
fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
|
||||
fused_mlp_module_type: Optional[FusedBaseMLPModule] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
|
@@ -210,14 +206,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if self.quantized:
|
||||
raise EnvironmentError("can't execute quantize because the model is quantized.")
|
||||
if use_triton and not TRITON_AVAILABLE:
|
||||
logger.warning("triton is not installed, reset use_triton to False")
|
||||
logger.warning("Triton is not installed, reset use_triton to False")
|
||||
use_triton = False
|
||||
|
||||
device_map = self.hf_device_map
|
||||
if device_map:
|
||||
for name, device in device_map.items():
|
||||
if device == "cpu":
|
||||
logger.info(f"truly offloading {name} to cpu with hook.")
|
||||
logger.info(f"Truly offloading {name} to cpu with hook.")
|
||||
module = get_module_by_name_suffix(self.model, name)
|
||||
remove_hook_from_module(module, recurse=True)
|
||||
accelerate.cpu_offload_with_hook(module, CUDA_0)
|
||||
|
@@ -437,10 +433,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
def generate(self, *args, **kwargs):
|
||||
"""shortcut for model.generate"""
|
||||
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
|
||||
return self.model.generate(**kwargs)
|
||||
with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
|
||||
return self.model.generate(*args, **kwargs)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
"""shortcut for model.prepare_inputs_for_generation"""
|
||||
|
@@ -489,8 +485,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
"""
|
||||
if (self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)) and save_dir is None:
|
||||
raise ValueError("Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading.")
|
||||
if (
|
||||
self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)
|
||||
) and save_dir is None:
|
||||
raise ValueError(
|
||||
"Quantized model should be saved first, or you can provide save_dir to "
|
||||
"make sure model is saved to local disk before uploading."
|
||||
)
|
||||
|
||||
if save_dir is not None:
|
||||
logger.info(f"Saving model to {save_dir}")
|
||||
|
@@ -517,16 +518,30 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
repo_type="model",
|
||||
)
|
||||
|
||||
def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
|
||||
def save_quantized(
|
||||
self,
|
||||
save_dir: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None
|
||||
):
|
||||
"""save quantized model and configs to local disk"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not self.quantized:
|
||||
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
|
||||
raise TypeError("Can only save quantized model, please execute .quantize() method first.")
|
||||
if self.injected_fused_attention or self.injected_fused_mlp:
|
||||
raise TypeError(
|
||||
"At least one of attention modules and mlp modules are injected with fused ops, "
|
||||
"please disable 'inject_fused_attention' and 'inject_fused_mlp' at model loading stage, "
|
||||
"and don't call ._fuse_attention() and ._fuse_mlp() methods before calling this method."
|
||||
)
|
||||
|
||||
self.model.to(CPU)
|
||||
|
||||
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
model_base_name = (
|
||||
self.quantize_config.model_file_base_name or
|
||||
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
)
|
||||
if use_safetensors:
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
|
@@ -546,13 +561,23 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
new_key = str(key)
|
||||
new_value = str(value)
|
||||
except Exception as e:
|
||||
raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
|
||||
raise TypeError(
|
||||
f"safetensors_metadata: both keys and values must be strings and "
|
||||
f"an error occured when trying to convert them: {e}"
|
||||
)
|
||||
if new_key in new_safetensors_metadata:
|
||||
logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
|
||||
logger.warning(
|
||||
f"After converting safetensors_metadata keys to strings, the key "
|
||||
f"'{new_key}' is duplicated. Ensure that all your metadata keys are "
|
||||
f"strings to avoid overwriting."
|
||||
)
|
||||
new_safetensors_metadata[new_key] = new_value
|
||||
safetensors_metadata = new_safetensors_metadata
|
||||
if converted_keys:
|
||||
logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
|
||||
logger.debug(
|
||||
f"One or more safetensors_metadata keys or values had to be converted to str(). "
|
||||
f"Final safetensors_metadata: {safetensors_metadata}"
|
||||
)
|
||||
|
||||
# Format is required to enable Accelerate to load the metadata
|
||||
# otherwise it raises an OSError
|
||||
|
@@ -576,9 +601,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
self.quantize_config.model_name_or_path = save_dir
|
||||
self.quantize_config.model_file_base_name = model_base_name
|
||||
|
||||
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_dir: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""alias of save_quantized"""
|
||||
logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
|
||||
logger.warning("You are using save_pretrained, which will re-direct to save_quantized.")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
|
||||
@classmethod
|
||||
|
@@ -672,9 +703,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
logger.warning("can't get model's sequence length from model config, will set to 4096.")
|
||||
logger.warning("Can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
model.eval()
|
||||
|
||||
return cls(model, False, quantize_config)
|
||||
|
||||
|
@@ -688,8 +718,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
low_cpu_mem_usage: bool = False,
|
||||
use_triton: bool = False,
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True,
|
||||
inject_fused_attention: bool = False,
|
||||
inject_fused_mlp: bool = False,
|
||||
use_cuda_fp16: bool = True,
|
||||
quantize_config: Optional[BaseQuantizeConfig] = None,
|
||||
model_basename: Optional[str] = None,
|
||||
|
@@ -697,6 +727,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
disable_exllama: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
|
@@ -727,7 +758,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
}
|
||||
|
||||
if use_triton and not TRITON_AVAILABLE:
|
||||
logger.warning("Triton is not installed, reset use_triton to False.")
|
||||
logger.warning("Triton is not installed, reset use_triton to False")
|
||||
use_triton = False
|
||||
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
|
||||
logger.warning(
|
||||
|
@@ -746,9 +777,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"2. You are using pytorch without CUDA support.\n"
|
||||
"3. CUDA and nvcc are not installed in your device."
|
||||
)
|
||||
if any([inject_fused_attention, inject_fused_mlp]) and trainable:
|
||||
logger.warning(
|
||||
"Neither fused attention nor fused mlp is tested under trainable mode, "
|
||||
"this may cause unexpected behavior or lead to error if you are training "
|
||||
"a quantized model with fused ops, please consider disabling 'inject_fused_attention' "
|
||||
"and 'inject_fused_mlp'."
|
||||
)
|
||||
|
||||
# == step1: prepare configs and file names == #
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**cached_file_kwargs
|
||||
)
|
||||
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
@@ -787,7 +829,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if resolved_archive_file is not None:
|
||||
break
|
||||
|
||||
if resolved_archive_file is None: # Could not find a model file to use
|
||||
if resolved_archive_file is None: # Could not find a model file to use
|
||||
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
|
||||
|
||||
model_save_name = resolved_archive_file
|
||||
|
@@ -795,9 +837,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if not disable_exllama and trainable:
|
||||
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
disable_exllama = True
|
||||
|
||||
elif not use_triton and trainable:
|
||||
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
logger.warning(
|
||||
"QuantLinear with cuda backend not support trainable mode yet, will switch to pytorch backend, "
|
||||
"this may cause very slow inference speed, disable trainable if you are not training model."
|
||||
)
|
||||
|
||||
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
|
||||
def skip(*args, **kwargs):
|
||||
|
@@ -881,7 +925,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
)
|
||||
model = simple_dispatch_model(model, device_map)
|
||||
|
||||
# == step4: set seqlen == #
|
||||
# == step4: post init model == #
|
||||
# Any post-initialization that require device information, for example buffers initialization on device.
|
||||
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
|
||||
|
||||
# == step5: set seqlen == #
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
|
@@ -890,50 +938,62 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
logger.warning("can't get model's sequence length from model config, will set to 4096.")
|
||||
logger.warning("Can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
|
||||
# == step5: (optional) inject optimized module == #
|
||||
# == step6: (optional) inject optimized module == #
|
||||
if inject_fused_attention:
|
||||
if cls.fused_attn_module_type is None:
|
||||
try:
|
||||
cls._fuse_attention(model, attn_op, trainable)
|
||||
except NotImplementedError:
|
||||
inject_fused_attention = False
|
||||
logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
|
||||
else:
|
||||
cls.fused_attn_module_type.inject_to_model(
|
||||
model,
|
||||
use_triton=use_triton,
|
||||
group_size=quantize_config.group_size,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable,
|
||||
bits=quantize_config.bits,
|
||||
disable_exllama=disable_exllama,
|
||||
logger.warning(
|
||||
f"{cls.__name__} doesn't support fusing attention yet, will skip inject fused attention."
|
||||
)
|
||||
except:
|
||||
logger.error(
|
||||
f"Inject fused attention failed, you can set 'inject_fused_attention' to False to "
|
||||
f"bypass the error for now and report it on github."
|
||||
)
|
||||
raise
|
||||
if inject_fused_mlp:
|
||||
if cls.fused_mlp_module_type is None:
|
||||
try:
|
||||
cls._fuse_mlp(model, trainable)
|
||||
except NotImplementedError:
|
||||
inject_fused_mlp = False
|
||||
logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
|
||||
else:
|
||||
cls.fused_mlp_module_type.inject_to_model(
|
||||
model,
|
||||
use_triton=use_triton
|
||||
logger.warning(
|
||||
f"{cls.__name__} doesn't support fusing mlp yet, will skip inject fused mlp."
|
||||
)
|
||||
except:
|
||||
logger.error(
|
||||
f"Inject fused mlp failed, you can set 'inject_fused_mlp' to False to "
|
||||
f"bypass the error for now and report it on github."
|
||||
)
|
||||
raise
|
||||
if inject_fused_attention or inject_fused_mlp:
|
||||
logger.warning(
|
||||
"You are using at least one of 'inject_fused_attention' and 'inject_fused_mlp' "
|
||||
"modes, which are now marked as experimental features, feel free to open an issue "
|
||||
"or ask any question about those two features on github if you encounter unexpected "
|
||||
"behaviors and errors."
|
||||
)
|
||||
|
||||
# Any post-initialization that require device information, for example buffers initialization on device.
|
||||
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
|
||||
|
||||
model.eval()
|
||||
# == step6: (optional) warmup triton == #
|
||||
# == step7: (optional) warmup triton == #
|
||||
if use_triton and warmup_triton:
|
||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||
QuantLinear.warmup(model, seqlen=model.seqlen)
|
||||
cls.warmup_triton(model)
|
||||
|
||||
if inject_fused_mlp and cls.fused_mlp_module_type is not None:
|
||||
cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
|
||||
|
||||
# == step7: make model compatible with peft
|
||||
cls.make_sure_compatible_with_peft(
|
||||
model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
|
||||
# == step8: convert all QuantLinear to sub-class of torch.nn.Linear
|
||||
# note if _fuse_attention() and _fuse_mlp() is implemented,
|
||||
# all QuantLinear will be converted to sub-class of torch.nn.Linear at injection stage
|
||||
GeneralQuantLinear.convert_to_torch_linear(
|
||||
model,
|
||||
dynamically_import_QuantLinear(
|
||||
use_triton,
|
||||
quantize_config.desc_act,
|
||||
quantize_config.group_size,
|
||||
quantize_config.bits,
|
||||
disable_exllama
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
|
@@ -942,39 +1002,35 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
quantize_config,
|
||||
is_triton_backend=use_triton,
|
||||
injected_fused_attention=inject_fused_attention,
|
||||
injected_fused_mlp=inject_fused_mlp and use_triton,
|
||||
injected_fused_mlp=inject_fused_mlp,
|
||||
trainable=trainable
|
||||
)
|
||||
|
||||
def warmup_triton(self, enabled: bool = True):
|
||||
@staticmethod
|
||||
def _fuse_attention(
|
||||
model: PreTrainedModel,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _fuse_mlp(
|
||||
model: PreTrainedModel,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def warmup_triton(model: nn.Module, enabled: bool = True) -> None:
|
||||
if not enabled:
|
||||
return
|
||||
if not TRITON_AVAILABLE:
|
||||
logger.warning(f"triton is not available, skip warmup stage directly.")
|
||||
logger.warning(f"Triton is not available, skip warmup stage directly.")
|
||||
return
|
||||
|
||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||
QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
|
||||
|
||||
if self.fused_mlp_module_type is not None:
|
||||
self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
|
||||
|
||||
def enable_trainable_mode(self, enabled: bool = True):
|
||||
if not self.is_triton_backend and enabled:
|
||||
raise NotImplementedError("For now, trainable mode only supports triton backend.")
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "trainable"):
|
||||
setattr(m, "trainable", enabled)
|
||||
|
||||
def disable_trainable_mode(self):
|
||||
self.enable_trainable_mode(enabled=False)
|
||||
|
||||
@staticmethod
|
||||
def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
|
||||
GeneralQuantLinear.inject_to_model(
|
||||
model,
|
||||
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
|
||||
)
|
||||
QuantLinear.warmup(model, seqlen=model.seqlen)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
|
|
|
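For orientation on the _base.py changes above (the new attn_op argument, the per-model _fuse_attention()/_fuse_mlp() hooks, and the inject_fused_* defaults flipping to False), here is a minimal, hedged loading sketch. The keyword names come from the signatures shown in the hunks above; the model id and the xformers op choice are placeholders, not values from this PR.

# Hedged sketch only; the model id and the AttentionOp choice are placeholders.
import xformers.ops as xop
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "some-user/some-gptq-model",              # placeholder path or hub id
    use_triton=False,
    inject_fused_attention=True,              # defaults flip to False in this diff, so opt in explicitly
    inject_fused_mlp=True,
    attn_op=(xop.fmha.flash.FwOp(), None),    # same (FwOp, BwOp) form as the default used in attention.py
    trainable=False,
    disable_exllama=False,
)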
@@ -1,9 +1,5 @@
from packaging.version import parse as parse_version
|
||||
|
||||
from torch import device
|
||||
|
||||
from ..utils.import_utils import compare_transformers_version
|
||||
|
||||
CPU = device("cpu")
|
||||
CUDA_0 = device("cuda:0")
|
||||
|
||||
|
@@ -20,9 +16,8 @@ SUPPORTED_MODELS = [
"RefinedWeb",
|
||||
"baichuan",
|
||||
"internlm",
|
||||
"llama",
|
||||
"qwen",
|
||||
]
|
||||
if compare_transformers_version("v4.28.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
|
||||
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]
|
||||
|
|
|
@@ -1,6 +1,8 @@
from inspect import signature
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from xformers.ops.fmha import AttentionOp
|
||||
|
||||
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
|
||||
from ._utils import check_and_get_model_type
|
||||
from .bloom import BloomGPTQForCausalLM
|
||||
|
@@ -81,6 +83,7 @@ class AutoGPTQForCausalLM:
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
disable_exllama: bool = False,
|
||||
**kwargs
|
||||
) -> BaseGPTQForCausalLM:
|
||||
|
@@ -121,6 +124,7 @@ class AutoGPTQForCausalLM:
trust_remote_code=trust_remote_code,
|
||||
warmup_triton=warmup_triton,
|
||||
trainable=trainable,
|
||||
attn_op=attn_op,
|
||||
disable_exllama=disable_exllama,
|
||||
**keywords
|
||||
)
|
||||
|
|
|
@@ -1,4 +1,19 @@
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import xformers.ops as xop
|
||||
from torch.cuda import empty_cache
|
||||
from transformers import PreTrainedModel
|
||||
from xformers.ops.fmha import AttentionOp
|
||||
|
||||
from ._base import *
|
||||
from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
|
||||
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
|
||||
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
|
||||
|
||||
|
||||
class BaiChuanFusedAttentionWithRope(FusedAttentionWithRoPE):
|
||||
pass
|
||||
|
||||
|
||||
class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
|
@@ -12,5 +27,49 @@ class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _fuse_attention(
|
||||
model: PreTrainedModel,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
model_config = model.config
|
||||
num_heads = model_config.num_attention_heads
|
||||
scale = (model_config.hidden_size // num_heads) ** -0.5
|
||||
layers = model.model.layers
|
||||
|
||||
rope_cache = build_rope_cache(
|
||||
rotary_dim=model_config.hidden_size // num_heads,
|
||||
max_position=model_config.max_position_embeddings,
|
||||
device=model.device,
|
||||
dtype=model.dtype
|
||||
)
|
||||
|
||||
for layer in layers:
|
||||
old_attn = layer.self_attn
|
||||
attn_device = old_attn.W_pack.qweight.data.device
|
||||
new_qkv_proj = FusedGeneralQuantLinear(old_attn.W_pack)
|
||||
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
|
||||
new_attn = BaiChuanFusedAttentionWithRope(
|
||||
qkv_proj=new_qkv_proj,
|
||||
out_proj=new_out_proj,
|
||||
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
|
||||
num_query_heads=num_heads,
|
||||
num_key_heads=num_heads,
|
||||
num_value_heads=num_heads,
|
||||
attn_dropout=0.0,
|
||||
resid_dropout=0.0,
|
||||
scale=scale,
|
||||
attention_ops=attn_op,
|
||||
outputs_handler=(lambda x, y, z: (x, z, y)),
|
||||
training=trainable
|
||||
)
|
||||
|
||||
layer.self_attn = new_attn
|
||||
|
||||
del old_attn
|
||||
|
||||
empty_cache()
|
||||
|
||||
|
||||
__all__ = ["BaiChuanGPTQForCausalLM"]
|
||||
|
|
|
@@ -1,4 +1,4 @@
from auto_gptq.modeling import BaseGPTQForCausalLM
|
||||
from ._base import BaseGPTQForCausalLM
|
||||
|
||||
|
||||
class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
|
@@ -14,4 +14,5 @@ class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.c_proj"]
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["GPTBigCodeGPTQForCausalLM"]
|
|
@@ -1,5 +1,136 @@
from copy import deepcopy
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import xformers.ops as xop
|
||||
from torch.cuda import empty_cache
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
|
||||
from xformers.ops.fmha import AttentionOp
|
||||
|
||||
from ._base import *
|
||||
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
|
||||
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
|
||||
from ..nn_modules.fused_modules.attention import FusedAttention
|
||||
from ..nn_modules.fused_modules.mlp import FusedMLP
|
||||
|
||||
|
||||
class GPTJFusedAttention(FusedAttention):
|
||||
def __init__(
|
||||
self,
|
||||
qkv_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
embed_positions: torch.Tensor,
|
||||
rotary_dim: Optional[int],
|
||||
num_query_heads: int,
|
||||
num_key_heads: int,
|
||||
num_value_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
resid_dropout: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
attention_ops: Optional[xop.AttentionOp] = None,
|
||||
outputs_handler: Optional[Callable] = None,
|
||||
training: bool = False,
|
||||
):
|
||||
super(GPTJFusedAttention, self).__init__(
|
||||
qkv_proj,
|
||||
out_proj,
|
||||
num_query_heads,
|
||||
num_key_heads,
|
||||
num_value_heads,
|
||||
attn_dropout,
|
||||
resid_dropout,
|
||||
scale,
|
||||
attention_ops,
|
||||
outputs_handler,
|
||||
training
|
||||
)
|
||||
self.embed_positions = embed_positions
|
||||
self.rotary_dim = rotary_dim
|
||||
|
||||
def _get_embed_positions(self, position_ids: torch.Tensor):
|
||||
return self.embed_positions.repeat(position_ids.shape[0], 1, 1)
|
||||
|
||||
def _apply_rotary(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bsz, seq_len = key.shape[:2]
|
||||
|
||||
dtype = query.dtype
|
||||
query = query.view(bsz, seq_len, self.num_query_heads, -1).to(dtype=torch.float)
|
||||
key = key.view(bsz, seq_len, self.num_key_heads, -1).to(dtype=torch.float)
|
||||
|
||||
embed_positions = self._get_embed_positions(position_ids)
|
||||
|
||||
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
|
||||
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
|
||||
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
|
||||
|
||||
if self.rotary_dim is not None:
|
||||
k_rot = key[:, :, :, : self.rotary_dim]
|
||||
k_pass = key[:, :, :, self.rotary_dim:]
|
||||
|
||||
q_rot = query[:, :, :, : self.rotary_dim]
|
||||
q_pass = query[:, :, :, self.rotary_dim:]
|
||||
|
||||
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
|
||||
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
|
||||
|
||||
key = torch.cat([k_rot, k_pass], dim=-1)
|
||||
query = torch.cat([q_rot, q_pass], dim=-1)
|
||||
else:
|
||||
key = apply_rotary_pos_emb(key, sin, cos)
|
||||
query = apply_rotary_pos_emb(query, sin, cos)
|
||||
|
||||
return query.view(bsz, seq_len, -1).to(dtype=dtype), key.view(bsz, seq_len, -1).to(dtype=dtype)
|
||||
|
||||
def _build_attn_bias(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Optional[xop.AttentionBias]:
|
||||
return xop.LowerTriangularMask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
**kwargs
|
||||
):
|
||||
bsz, seq_len = hidden_states.shape[:2]
|
||||
|
||||
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||
|
||||
if position_ids is not None:
|
||||
q, k = self._apply_rotary(q, k, position_ids)
|
||||
|
||||
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if layer_past is None else None
|
||||
attn_out, present = self._attn(
|
||||
bsz,
|
||||
seq_len,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_bias,
|
||||
use_cache,
|
||||
layer_past
|
||||
)
|
||||
|
||||
out = self.out_proj(attn_out)
|
||||
out = self.resid_dropout(out)
|
||||
|
||||
outputs = (out, present, None)
|
||||
if self.outputs_handler:
|
||||
outputs = self.outputs_handler(*outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
|
@@ -13,7 +144,75 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"]
|
||||
]
|
||||
|
||||
fused_attn_module_type = FusedGPTJAttentionForQuantizedModel
|
||||
@staticmethod
|
||||
def _fuse_attention(
|
||||
model: PreTrainedModel,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
model_config = model.config
|
||||
num_heads = model_config.n_head
|
||||
scale = (model_config.hidden_size // num_heads) ** -0.5
|
||||
layers = model.transformer.h
|
||||
|
||||
for layer in layers:
|
||||
old_attn = layer.attn
|
||||
device = old_attn.q_proj.qweight.data.device
|
||||
new_qkv_proj = FusedGeneralQuantLinear.fuse(
|
||||
old_attn.q_proj,
|
||||
old_attn.k_proj,
|
||||
old_attn.v_proj
|
||||
)
|
||||
new_out_proj = FusedGeneralQuantLinear(old_attn.out_proj)
|
||||
new_attn = GPTJFusedAttention(
|
||||
qkv_proj=new_qkv_proj,
|
||||
out_proj=new_out_proj,
|
||||
embed_positions=old_attn.embed_positions.to(device),
|
||||
rotary_dim=old_attn.rotary_dim,
|
||||
num_query_heads=num_heads,
|
||||
num_key_heads=num_heads,
|
||||
num_value_heads=num_heads,
|
||||
attn_dropout=model_config.attn_pdrop,
|
||||
resid_dropout=model_config.resid_pdrop,
|
||||
scale=scale,
|
||||
attention_ops=attn_op,
|
||||
outputs_handler=None,
|
||||
training=trainable
|
||||
)
|
||||
|
||||
layer.attn = new_attn
|
||||
|
||||
del old_attn
|
||||
|
||||
empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _fuse_mlp(
|
||||
model: PreTrainedModel,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
model_config = model.config
|
||||
act = ACT2FN[model_config.activation_function]
|
||||
out_dropout = model_config.resid_pdrop
|
||||
layers = model.transformer.h
|
||||
|
||||
for layer in layers:
|
||||
old_mlp = layer.mlp
|
||||
new_mlp = FusedMLP(
|
||||
input_proj=FusedGeneralQuantLinear(old_mlp.fc_in),
|
||||
out_proj=FusedGeneralQuantLinear(old_mlp.fc_out),
|
||||
activation=act,
|
||||
in_dropout=0.0,
|
||||
out_dropout=out_dropout,
|
||||
training=trainable,
|
||||
residual=False
|
||||
)
|
||||
|
||||
layer.mlp = new_mlp
|
||||
|
||||
del old_mlp
|
||||
|
||||
empty_cache()
|
||||
|
||||
|
||||
__all__ = ["GPTJGPTQForCausalLM"]
|
||||
|
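The GPTJ fusion above keeps GPT-J's partial-rotary convention: only the first rotary_dim channels of each head are rotated, and the remainder passes through untouched. A toy illustration of that split (the shapes are illustrative, not taken from a real checkpoint):

import torch

# Illustrative shapes only: rotate the first rotary_dim channels per head,
# pass the rest through, then concatenate, as in GPTJFusedAttention._apply_rotary.
bsz, seq_len, num_heads, head_dim, rotary_dim = 1, 8, 16, 32, 8
key = torch.randn(bsz, seq_len, num_heads, head_dim)
k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
# ... apply_rotary_pos_emb(k_rot, sin, cos) would be applied here ...
key = torch.cat([k_rot, k_pass], dim=-1)
assert key.shape == (bsz, seq_len, num_heads, head_dim)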
|
|
@@ -1,16 +1,18 @@
from logging import getLogger
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
from torch.cuda import empty_cache
|
||||
from transformers import PreTrainedModel
|
||||
from xformers.ops.fmha import AttentionOp
|
||||
|
||||
from ._base import *
|
||||
from ..utils.import_utils import compare_transformers_version
|
||||
from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
|
||||
from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
|
||||
from ..nn_modules.fused_modules.mlp import FusedGatedMLP
|
||||
|
||||
if compare_transformers_version("v4.28.0", op="ge"):
|
||||
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
|
||||
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
|
||||
else:
|
||||
FusedLlamaAttentionForQuantizedModel = None
|
||||
FusedLlamaMLPForQuantizedModel = None
|
||||
|
||||
logger = getLogger(__name__)
|
||||
class LlamaFusedAttentionWithRoPE(FusedAttentionWithRoPE):
|
||||
pass
|
||||
|
||||
|
||||
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
|
@@ -24,8 +26,61 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
|
||||
]
|
||||
|
||||
fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
|
||||
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel
|
||||
@staticmethod
|
||||
def _fuse_attention(
|
||||
model: PreTrainedModel,
|
||||
attn_op: Optional[AttentionOp] = None,
|
||||
trainable: bool = False
|
||||
) -> None:
|
||||
model_config = model.config
|
||||
num_heads = model_config.num_attention_heads
|
||||
scale = (model_config.hidden_size // num_heads) ** -0.5
|
||||
layers = model.model.layers
|
||||
|
||||
rope_cache = build_rope_cache(
|
||||
rotary_dim=model_config.hidden_size // num_heads,
|
||||
max_position=model_config.max_position_embeddings,
|
||||
base=10000,
|
||||
device=model.device,
|
||||
dtype=model.dtype
|
||||
)
|
||||
|
||||
for layer in layers:
|
||||
old_attn = layer.self_attn
|
||||
attn_device = old_attn.q_proj.qweight.data.device
|
||||
new_qkv_proj = FusedGeneralQuantLinear.fuse(
|
||||
old_attn.q_proj,
|
||||
old_attn.k_proj,
|
||||
old_attn.v_proj
|
||||
)
|
||||
new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
|
||||
new_attn = LlamaFusedAttentionWithRoPE(
|
||||
qkv_proj=new_qkv_proj,
|
||||
out_proj=new_out_proj,
|
||||
cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
|
||||
num_query_heads=num_heads,
|
||||
num_key_heads=num_heads,
|
||||
num_value_heads=num_heads,
|
||||
attn_dropout=0.0,
|
||||
resid_dropout=0.0,
|
||||
scale=scale,
|
||||
attention_ops=attn_op,
|
||||
outputs_handler=(lambda x, y, z: (x, z, y)),
|
||||
training=trainable
|
||||
)
|
||||
|
||||
layer.self_attn = new_attn
|
||||
|
||||
del old_attn
|
||||
|
||||
empty_cache()
|
||||
|
||||
# @staticmethod
|
||||
# def _fuse_mlp(
|
||||
# model: PreTrainedModel,
|
||||
# trainable: bool = False
|
||||
# ) -> None:
|
||||
# pass
|
||||
|
||||
|
||||
__all__ = ["LlamaGPTQForCausalLM"]
|
||||
|
|
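One non-obvious detail in the BaiChuan and Llama fusions above is the outputs_handler lambda. A hedged reading, based on the (out, present, None) tuple built in FusedAttention.forward: the third element is the (always None) attention weights, while the Hugging Face attention modules being replaced are expected to return (output, attn_weights, present), so the lambda simply swaps the last two items. An equivalent, more explicit form:

# Assumed meaning of outputs_handler=(lambda x, y, z: (x, z, y)):
# x = attn_output, y = present_key_value, z = attn_weights (None here).
def reorder_for_hf(attn_output, present_key_value, attn_weights):
    return attn_output, attn_weights, present_key_value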
auto_gptq/nn_modules/fused_modules/__init__.py (new file, 0 lines)
auto_gptq/nn_modules/fused_modules/attention.py (new file, 394 lines)
|
@@ -0,0 +1,394 @@
import math
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import xformers.ops as xop
|
||||
from vllm import pos_encoding_ops as vllm_pos_encoding_ops
|
||||
from xformers.ops.fmha.attn_bias import LowerTriangularMask, LowerTriangularMaskWithTensorBias
|
||||
|
||||
|
||||
POTENTIAL_KV_CACHE_NAMES = (
|
||||
"past_key_value",
|
||||
"layer_past",
|
||||
"kv_cache"
|
||||
)
|
||||
|
||||
|
||||
def _try_to_get_kv_cache(**kwargs) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
kv_cache = None
|
||||
for name in POTENTIAL_KV_CACHE_NAMES:
|
||||
if name in kwargs:
|
||||
return kwargs[name]
|
||||
return kv_cache
|
||||
|
||||
|
||||
def build_rope_cache(
|
||||
rotary_dim: int,
|
||||
max_position: int = 2048,
|
||||
base: int = 10000,
|
||||
device: torch.device = torch.device("cuda:0"),
|
||||
dtype: torch.dtype = torch.float16
|
||||
): # TODO: support multiple scaling strategies
|
||||
inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim)))
|
||||
t = torch.arange(max_position, device=device, dtype=dtype)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
|
||||
return cache
|
||||
|
||||
|
||||
def build_alibi_slopes(
|
||||
num_heads: int,
|
||||
device: torch.device = torch.device("cuda:0"),
|
||||
dtype: torch.dtype = torch.float16
|
||||
):
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
slopes = slopes.to(dtype)
|
||||
|
||||
return slopes
|
||||
|
||||
|
||||
def attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_ops: Optional[xop.AttentionOp] = (xop.fmha.flash.FwOp(), None),
|
||||
attention_bias: Optional[xop.AttentionBias] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None
|
||||
):
|
||||
if value.shape[2] != query.shape[2]:
|
||||
# MQA expand
|
||||
if value.shape[2] == 1:
|
||||
pass # TODO
|
||||
# GQA reshape
|
||||
else:
|
||||
original_shape = value.shape
|
||||
pass # TODO
|
||||
|
||||
return xop.memory_efficient_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=attention_bias,
|
||||
p=p,
|
||||
scale=scale,
|
||||
op=attention_ops
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qkv_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
num_query_heads: int,
|
||||
num_key_heads: int,
|
||||
num_value_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
resid_dropout: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
attention_ops: Optional[xop.AttentionOp] = None,
|
||||
outputs_handler: Optional[Callable] = None,
|
||||
training: bool = False,
|
||||
):
|
||||
super(FusedAttention, self).__init__()
|
||||
|
||||
self.qkv_proj = qkv_proj
|
||||
self.out_proj = out_proj
|
||||
|
||||
self.num_query_heads = num_query_heads
|
||||
self.num_key_heads = num_key_heads
|
||||
self.num_value_heads = num_value_heads
|
||||
|
||||
self.attn_dropout = attn_dropout if training else 0.0
|
||||
self.scale = scale
|
||||
|
||||
self.attention_ops = attention_ops
|
||||
|
||||
self.outputs_handler = outputs_handler
|
||||
|
||||
self.resid_dropout = nn.Dropout(resid_dropout if training else 0.0)
|
||||
|
||||
def _build_attn_bias(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Optional[xop.AttentionBias]:
|
||||
return None
|
||||
|
||||
def _attn(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
attention_bias: Optional[xop.AttentionBias] = None,
|
||||
use_cache: bool = False,
|
||||
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
):
|
||||
q = q.view(batch_size, seq_len, self.num_query_heads, -1).transpose(1, 2)
|
||||
k = k.view(batch_size, seq_len, self.num_key_heads, -1).transpose(1, 2)
|
||||
v = v.view(batch_size, seq_len, self.num_value_heads, -1).transpose(1, 2)
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
k = torch.cat((k_cache, k), dim=2)
|
||||
v = torch.cat((v_cache, v), dim=2)
|
||||
|
||||
present = None
|
||||
if use_cache:
|
||||
present = (k, v)
|
||||
|
||||
attn_out = attention(
|
||||
query=q.transpose(1, 2),
|
||||
key=k.transpose(1, 2),
|
||||
value=v.transpose(1, 2),
|
||||
attention_ops=self.attention_ops,
|
||||
attention_bias=attention_bias,
|
||||
p=self.attn_dropout,
|
||||
scale=self.scale
|
||||
).view(batch_size, seq_len, -1)
|
||||
|
||||
return attn_out, present
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
):
|
||||
bsz, seq_len = hidden_states.shape[:2]
|
||||
use_cache = kwargs.get("use_cache", False)
|
||||
kv_cache = _try_to_get_kv_cache(**kwargs)
|
||||
|
||||
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||
|
||||
attn_bias = self._build_attn_bias(hidden_states, attention_mask)
|
||||
attn_out, present = self._attn(
|
||||
bsz,
|
||||
seq_len,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_bias,
|
||||
use_cache,
|
||||
kv_cache
|
||||
)
|
||||
|
||||
out = self.out_proj(attn_out)
|
||||
out = self.resid_dropout(out)
|
||||
|
||||
outputs = (out, present, None)
|
||||
if self.outputs_handler:
|
||||
outputs = self.outputs_handler(*outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FusedAttentionWithRoPE(FusedAttention):
|
||||
def __init__(
|
||||
self,
|
||||
qkv_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
num_query_heads: int,
|
||||
num_key_heads: int,
|
||||
num_value_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
resid_dropout: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
attention_ops: Optional[xop.AttentionOp] = None,
|
||||
outputs_handler: Optional[Callable] = None,
|
||||
training: bool = False,
|
||||
):
|
||||
super(FusedAttentionWithRoPE, self).__init__(
|
||||
qkv_proj=qkv_proj,
|
||||
out_proj=out_proj,
|
||||
num_query_heads=num_query_heads,
|
||||
num_key_heads=num_key_heads,
|
||||
num_value_heads=num_value_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
resid_dropout=resid_dropout,
|
||||
scale=scale,
|
||||
attention_ops=attention_ops,
|
||||
outputs_handler=outputs_handler,
|
||||
training=training
|
||||
)
|
||||
|
||||
self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)
|
||||
|
||||
def _build_attn_bias(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Optional[xop.AttentionBias]:
|
||||
return LowerTriangularMask()
|
||||
|
||||
def _apply_rotary_embedding(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None
|
||||
):
|
||||
bsz, seq_len, hidden_size = query.shape
|
||||
|
||||
if position_ids is not None:
|
||||
query = query.view(bsz * seq_len, -1)
|
||||
key = key.view(bsz * seq_len, -1)
|
||||
vllm_pos_encoding_ops.rotary_embedding_neox(
|
||||
position_ids.view(-1).to(query.device),
|
||||
query,
|
||||
key,
|
||||
hidden_size // self.num_query_heads,
|
||||
self.cos_sin_cache,
|
||||
)
|
||||
query = query.view(bsz, seq_len, -1)
|
||||
key = key.view(bsz, seq_len, -1)
|
||||
|
||||
return query, key
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
):
|
||||
bsz, seq_len = hidden_states.shape[:2]
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
use_cache = kwargs.get("use_cache", False)
|
||||
kv_cache = _try_to_get_kv_cache(**kwargs)
|
||||
|
||||
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||
|
||||
q, k = self._apply_rotary_embedding(q, k, position_ids)
|
||||
|
||||
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
|
||||
attn_out, present = self._attn(
|
||||
bsz,
|
||||
seq_len,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_bias,
|
||||
use_cache,
|
||||
kv_cache
|
||||
)
|
||||
|
||||
out = self.out_proj(attn_out)
|
||||
out = self.resid_dropout(out)
|
||||
|
||||
outputs = (out, present, None)
|
||||
if self.outputs_handler:
|
||||
outputs = self.outputs_handler(*outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FusedAttentionWithALiBi(FusedAttention):
|
||||
def __init__(
|
||||
self,
|
||||
qkv_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
alibi_slopes: torch.Tensor,
|
||||
num_query_heads: int,
|
||||
num_key_heads: int,
|
||||
num_value_heads: int,
|
||||
attn_dropout: float = 0.0,
|
||||
resid_dropout: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
attention_ops: Optional[xop.AttentionOp] = None,
|
||||
outputs_handler: Optional[Callable] = None,
|
||||
training: bool = False,
|
||||
):
|
||||
super(FusedAttentionWithALiBi, self).__init__(
|
||||
qkv_proj=qkv_proj,
|
||||
out_proj=out_proj,
|
||||
num_query_heads=num_query_heads,
|
||||
num_key_heads=num_key_heads,
|
||||
num_value_heads=num_value_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
resid_dropout=resid_dropout,
|
||||
scale=scale,
|
||||
attention_ops=attention_ops,
|
||||
outputs_handler=outputs_handler,
|
||||
training=training
|
||||
)
|
||||
|
||||
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
|
||||
|
||||
def _build_attn_bias(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Optional[xop.AttentionBias]: # adopt from vllm
|
||||
bsz, seq_len = hidden_states.shape[:2]
|
||||
|
||||
bias = torch.arange(seq_len)
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
bias = bias.to(hidden_states.device)
|
||||
|
||||
# When using custom attention bias, xformers requires the bias to
|
||||
# be sliced from a tensor whose length is a multiple of 8.
|
||||
padded_len = (seq_len + 7) // 8 * 8
|
||||
bias = torch.empty(
|
||||
self.num_query_heads,
|
||||
padded_len,
|
||||
padded_len,
|
||||
device=self.alibi_slopes.device,
|
||||
)[:, :seq_len, :seq_len].copy_(bias)
|
||||
bias.mul_(self.alibi_slopes[:, None, None])
|
||||
bias = LowerTriangularMaskWithTensorBias(bias.unsqueeze(0).repeat(bsz, 1, 1, 1))
|
||||
|
||||
return bias
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
):
|
||||
bsz, seq_len = hidden_states.shape[:2]
|
||||
use_cache = kwargs.get("use_cache", False)
|
||||
kv_cache = _try_to_get_kv_cache(**kwargs)
|
||||
|
||||
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||
|
||||
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
|
||||
attn_out, present = self._attn(
|
||||
bsz,
|
||||
seq_len,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_bias,
|
||||
use_cache,
|
||||
kv_cache
|
||||
)
|
||||
|
||||
out = self.out_proj(attn_out)
|
||||
out = self.resid_dropout(out)
|
||||
|
||||
outputs = (out, present, None)
|
||||
if self.outputs_handler:
|
||||
outputs = self.outputs_handler(*outputs)
|
||||
|
||||
return outputs
|
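As a quick sanity check of the layout produced by build_rope_cache above: the cos half and the sin half are concatenated along the last dimension, giving a [max_position, rotary_dim] buffer that FusedAttentionWithRoPE registers as cos_sin_cache. A standalone re-derivation with illustrative sizes:

import torch

# Re-derive the cache on CPU with illustrative sizes; build_rope_cache above does
# the same computation on the model's device and dtype.
rotary_dim, max_position, base = 128, 2048, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
assert cache.shape == (max_position, rotary_dim)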
auto_gptq/nn_modules/fused_modules/linear.py (new file, 97 lines)
|
@@ -0,0 +1,97 @@
import torch
|
||||
|
||||
from ..qlinear import GeneralQuantLinear
|
||||
from ..qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
|
||||
from ..qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
|
||||
try:
|
||||
from ..qlinear.qlinear_triton import QuantLinear as TritonQuantLinear
|
||||
except:
|
||||
TritonQuantLinear = None
|
||||
try:
|
||||
from ..qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
|
||||
except:
|
||||
ExllamaQuantLinear = None
|
||||
|
||||
|
||||
class FusedGeneralQuantLinear(GeneralQuantLinear):
|
||||
def __init__(self, quant_linear_module):
|
||||
super(FusedGeneralQuantLinear, self).__init__(quant_linear_module)
|
||||
|
||||
@classmethod
|
||||
def fuse(
|
||||
cls,
|
||||
q_proj,
|
||||
k_proj=None,
|
||||
v_proj=None,
|
||||
):
|
||||
qweights, qzeros, scales, g_idx, bias = [], [], [], [], []
|
||||
outfeatures = 0
|
||||
for module in [q_proj, k_proj, v_proj]:
|
||||
if module is not None:
|
||||
qweights.append(module.qweight)
|
||||
qzeros.append(module.qzeros)
|
||||
scales.append(module.scales)
|
||||
g_idx.append(module.g_idx)
|
||||
bias.append(module.bias)
|
||||
outfeatures += module.outfeatures
|
||||
|
||||
if bias[0] is None:
|
||||
bias = None
|
||||
|
||||
if len(qweights) > 1:
|
||||
qweights = torch.cat(qweights, dim=1)
|
||||
qzeros = torch.cat(qzeros, dim=1)
|
||||
scales = torch.cat(scales, dim=1)
|
||||
g_idx = torch.cat(g_idx, dim=0)
|
||||
if bias is not None:
|
||||
bias = torch.cat(bias, dim=0)
|
||||
|
||||
qlinear_args = (
|
||||
q_proj.bits,
|
||||
q_proj.group_size,
|
||||
q_proj.infeatures,
|
||||
outfeatures,
|
||||
bias is not None
|
||||
)
|
||||
qlinear_kwargs = {"trainable": q_proj.trainable}
|
||||
if isinstance(q_proj, (OldCudaQuantLinear, CudaQuantLinear)):
|
||||
qlinear_kwargs["kernel_switch_threshold"] = q_proj.kernel_switch_threshold
|
||||
if isinstance(q_proj, OldCudaQuantLinear):
|
||||
qlinear_kwargs["use_cuda_fp16"] = q_proj.use_cuda_fp16
|
||||
QuantLinear = OldCudaQuantLinear
|
||||
else:
|
||||
QuantLinear = CudaQuantLinear
|
||||
elif isinstance(q_proj, TritonQuantLinear):
|
||||
QuantLinear = TritonQuantLinear
|
||||
else:
|
||||
QuantLinear = ExllamaQuantLinear
|
||||
fused_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
|
||||
|
||||
fused_proj.qweight = qweights
|
||||
fused_proj.qzeros = qzeros
|
||||
fused_proj.scales = scales
|
||||
fused_proj.g_idx = g_idx
|
||||
fused_proj.bias = bias
|
||||
|
||||
if isinstance(q_proj, ExllamaQuantLinear):
|
||||
if not hasattr(q_proj, "_use_act_order"):
|
||||
raise AttributeError(
|
||||
"q_proj doesn't have attribute _use_act_order, please execute "
|
||||
"auto_gptq.modeling._utils.autogptq_post_init function before "
|
||||
"fuse quant linears."
|
||||
)
|
||||
if q_proj._use_act_order:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
raise ValueError(
|
||||
"Exllama kernel does not support layer fusion with act-order. "
|
||||
"Please either use inject_fused_attention=False or disable_exllama=True."
|
||||
)
|
||||
fused_proj._use_act_order = q_proj._use_act_order
|
||||
fused_proj.g_idx = None
|
||||
fused_proj.post_init()
|
||||
|
||||
del q_proj, k_proj, v_proj
|
||||
|
||||
return cls(fused_proj)
|
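To make the tensor bookkeeping in fuse() above concrete: the packed buffers of the q/k/v projections are concatenated along the output axis, so the fused QKV projection is three times as wide as each input. A toy illustration with hypothetical 4-bit shapes (not read from a real checkpoint):

import torch

# Hypothetical packed shapes for three 4-bit QuantLinear layers with
# infeatures = outfeatures = 4096: qweight is stored as int32 with 8 weights
# per element, i.e. shape (infeatures // 8, outfeatures).
qweight_q = torch.zeros(4096 // 8, 4096, dtype=torch.int32)
qweight_k = torch.zeros_like(qweight_q)
qweight_v = torch.zeros_like(qweight_q)

# fuse() concatenates along dim=1, yielding one projection with outfeatures = 12288.
fused_qweight = torch.cat([qweight_q, qweight_k, qweight_v], dim=1)
assert fused_qweight.shape == (512, 3 * 4096)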
auto_gptq/nn_modules/fused_modules/mlp.py (new file, 168 lines)
|
@@ -0,0 +1,168 @@
from functools import partial
|
||||
from typing import Union, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functorch.compile import memory_efficient_fusion
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def act_dropout(
|
||||
hidden_states: torch.Tensor,
|
||||
activation: Union[Callable, nn.Module],
|
||||
dropout: float = 0.0
|
||||
):
|
||||
hidden_states = activation(hidden_states)
|
||||
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
|
||||
|
||||
|
||||
def dropout_res(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
dropout: float = 0.0
|
||||
):
|
||||
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
|
||||
return torch.add(hidden_states, residual)
|
||||
|
||||
|
||||
def act_dropout_res(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
activation: Union[Callable, nn.Module],
|
||||
dropout: float = 0.0
|
||||
):
|
||||
hidden_states = activation(hidden_states)
|
||||
hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
|
||||
return torch.add(hidden_states, residual)
|
||||
|
||||
|
||||
class NVFusedActDropoutRes(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: Optional[Union[Callable, nn.Module]] = None,
|
||||
dropout: float = 0.0,
|
||||
residual: bool = False,
|
||||
is_cuda: bool = False
|
||||
):
|
||||
super(NVFusedActDropoutRes, self).__init__()
|
||||
|
||||
fn = partial(F.dropout, p=dropout)
|
||||
if activation is not None and residual:
|
||||
fn = partial(act_dropout_res, activation=activation, dropout=dropout)
|
||||
elif activation is not None:
|
||||
fn = partial(act_dropout, activation=activation, dropout=dropout)
|
||||
elif residual:
|
||||
fn = partial(dropout_res, dropout=dropout)
|
||||
|
||||
self.fn = fn
|
||||
if is_cuda:
|
||||
self.fn = memory_efficient_fusion(self.fn)
|
||||
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
|
||||
if self.residual:
|
||||
return self.fn(hidden_states, residual)
|
||||
else:
|
||||
return self.fn(hidden_states)
|
||||
|
||||
|
||||
class FusedMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
activation: Optional[Union[Callable, nn.Module]] = None,
|
||||
in_dropout: float = 0.0,
|
||||
out_dropout: float = 0.0,
|
||||
training: bool = False,
|
||||
residual: bool = False
|
||||
):
|
||||
super(FusedMLP, self).__init__()
|
||||
|
||||
if activation is None:
|
||||
activation = nn.Identity()
|
||||
|
||||
is_cuda = input_proj.weight.data.device.type == "cuda"
|
||||
|
||||
self.input_proj = input_proj
|
||||
self.fused_op1 = NVFusedActDropoutRes(
|
||||
activation=activation,
|
||||
dropout=in_dropout if training else 0.0,
|
||||
residual=False,
|
||||
is_cuda=is_cuda
|
||||
)
|
||||
self.out_proj = out_proj
|
||||
self.fused_op2 = NVFusedActDropoutRes(
|
||||
activation=None,
|
||||
dropout=out_dropout if training else 0.0,
|
||||
residual=residual,
|
||||
is_cuda=is_cuda
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
|
||||
return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)
|
||||
|
||||
|
||||
def gated_act_dropout(
|
||||
gate_states: torch.Tensor,
|
||||
up_states: torch.Tensor,
|
||||
activation: Union[Callable, nn.Module],
|
||||
dropout: float = 0.0
|
||||
):
|
||||
hidden_states = activation(gate_states) * up_states
|
||||
return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
|
||||
|
||||
|
||||
class NVFusedGatedActDropout(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: Optional[Union[Callable, nn.Module]] = None,
|
||||
dropout: float = 0.0,
|
||||
is_cuda: bool = False
|
||||
):
|
||||
super(NVFusedGatedActDropout, self).__init__()
|
||||
|
||||
fn = partial(F.dropout, p=dropout)
|
||||
if activation is not None:
|
||||
fn = partial(gated_act_dropout, activation=activation, dropout=dropout)
|
||||
|
||||
self.fn = fn
|
||||
if is_cuda:
|
||||
self.fn = memory_efficient_fusion(self.fn)
|
||||
|
||||
def forward(self, gate_states: torch.Tensor, up_states):
|
||||
if isinstance(self.fn, nn.Dropout):
|
||||
return self.fn(gate_states * up_states)
|
||||
return self.fn(gate_states, up_states)
|
||||
|
||||
|
||||
class FusedGatedMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_proj: nn.Linear,
|
||||
out_proj: nn.Linear,
|
||||
activation: Optional[Union[Callable, nn.Module]] = None,
|
||||
in_dropout: float = 0.0,
|
||||
out_dropout: float = 0.0,
|
||||
training: bool = False
|
||||
):
|
||||
super(FusedGatedMLP, self).__init__()
|
||||
|
||||
if activation is None:
|
||||
activation = nn.Identity()
|
||||
|
||||
self.input_proj = input_proj
|
||||
self.fused_op = NVFusedGatedActDropout(
|
||||
activation=activation,
|
||||
dropout=in_dropout if training else 0.0,
|
||||
is_cuda=input_proj.weight.data.device.type == "cuda"
|
||||
)
|
||||
self.out_proj = out_proj
|
||||
self.out_dropout = nn.Dropout(out_dropout)
|
||||
|
||||
self.intermediate_size = self.input_proj.out_features // 2
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = self.input_proj(hidden_states)
|
||||
return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
|
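For reference, the computation FusedGatedMLP performs above is the usual gated-MLP pattern: input_proj packs the gate and up projections, chunk(2, dim=-1) splits them back out, and activation(gate) * up is projected down. A plain-PyTorch sketch with dropout omitted and SiLU chosen only as an example activation:

import torch
import torch.nn.functional as F

def gated_mlp_reference(hidden_states: torch.Tensor,
                        input_proj: torch.nn.Linear,
                        out_proj: torch.nn.Linear) -> torch.Tensor:
    # input_proj concatenates the gate and up projections; split them apart,
    # gate with the activation, then project back down.
    gate, up = input_proj(hidden_states).chunk(2, dim=-1)
    return out_proj(F.silu(gate) * up)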
|
@@ -6,7 +6,7 @@ class GeneralQuantLinear(nn.Linear):
super().__init__(
|
||||
in_features=quant_linear_module.infeatures,
|
||||
out_features=quant_linear_module.outfeatures,
|
||||
bias=True
|
||||
bias=quant_linear_module.bias is not None
|
||||
)
|
||||
self.infeatures = quant_linear_module.infeatures
|
||||
self.outfeatures = quant_linear_module.outfeatures
|
||||
|
@@ -18,28 +18,47 @@ class GeneralQuantLinear(nn.Linear):
|
||||
self.weight.data = quant_linear_module.qweight
|
||||
self.register_buffer('qweight', quant_linear_module.qweight)
|
||||
self.bias.data = quant_linear_module.bias
|
||||
|
||||
self.qweight.requires_grad = False
|
||||
self.bias.requires_grad = False
|
||||
if quant_linear_module.bias is not None:
|
||||
self.bias.data = quant_linear_module.bias
|
||||
|
||||
self.register_buffer('qzeros', quant_linear_module.qzeros)
|
||||
self.register_buffer('scales', quant_linear_module.scales)
|
||||
self.register_buffer('g_idx', quant_linear_module.g_idx)
|
||||
|
||||
# arg of qlinear_cuda and qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "wf"):
|
||||
self.wf = quant_linear_module.wf
|
||||
# arg of qlinaer_cuda and qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "kernel_switch_threshold"):
|
||||
self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
|
||||
# arg of qlinaer_cuda and qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "autogptq_cuda_available"):
|
||||
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
|
||||
# arg of qlinaer_cuda and qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "autogptq_cuda"):
|
||||
self.autogptq_cuda = quant_linear_module.autogptq_cuda
|
||||
# arg of qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "half_indim"):
|
||||
self.half_indim = quant_linear_module.half_indim
|
||||
# arg of qlinear_cuda_old
|
||||
if hasattr(quant_linear_module, "use_cuda_fp16"):
|
||||
self.use_cuda_fp16 = quant_linear_module.use_cuda_fp16
|
||||
# args of qlinear_exllama
|
||||
if hasattr(quant_linear_module, "_use_act_order"):
|
||||
self._use_act_order = quant_linear_module._use_act_order
|
||||
# arg of qlinaer_exllama
|
||||
if hasattr(quant_linear_module, "width"):
|
||||
self.width = quant_linear_module.width
|
||||
# arg of qlinear_exllama
|
||||
if hasattr(quant_linear_module, "q4"):
|
||||
self.q4 = quant_linear_module.q4
|
||||
|
||||
self.trainable = quant_linear_module.trainable
|
||||
|
||||
self.forward = quant_linear_module.forward
|
||||
|
||||
@classmethod
|
||||
def inject_to_model(cls, model, target_module_type):
|
||||
def convert_to_torch_linear(cls, model: nn.Module, target_module_type: "QuantLinear"):
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, target_module_type):
|
||||
continue
|
||||
|
|
|
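Only the loop header of the renamed convert_to_torch_linear is visible in this view; the rest of the method is elided. As a rough, self-contained illustration of the module-swap pattern it appears to implement (toy classes only, not the real QuantLinear):

import torch.nn as nn

class ToyQuantLinear(nn.Module):
    # Stand-in for a kernel-specific QuantLinear; only carries the shape info we need.
    def __init__(self, infeatures: int, outfeatures: int):
        super().__init__()
        self.infeatures, self.outfeatures = infeatures, outfeatures

def swap_modules(model: nn.Module, target_cls, make_replacement):
    # Walk named_modules(), find instances of the target class, and re-register a
    # replacement module on the parent, so downstream code only sees nn.Linear-likes.
    for name, module in list(model.named_modules()):
        if not isinstance(module, target_cls):
            continue
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_replacement(module))

toy = nn.Sequential(ToyQuantLinear(16, 32), nn.ReLU())
swap_modules(toy, ToyQuantLinear, lambda m: nn.Linear(m.infeatures, m.outfeatures))
print(toy)  # the ToyQuantLinear child is now a plain nn.Linear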
@@ -36,8 +36,6 @@ class QuantLinear(nn.Module):
        global _autogptq_cuda_available
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        if trainable:
            _autogptq_cuda_available = False

        self.infeatures = infeatures
        self.outfeatures = outfeatures

@@ -198,7 +196,7 @@ class QuantLinear(nn.Module):
        x = x.reshape(-1, x.shape[-1])
        if self.autogptq_cuda_available and (
            self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
        ):
        ) and not self.trainable:
            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
            if self.bits == 2:
                self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
@@ -36,8 +36,7 @@ class QuantLinear(nn.Module):
        global _autogptq_cuda_available
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        if trainable:
            _autogptq_cuda_available = False

        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits

@@ -198,7 +197,7 @@ class QuantLinear(nn.Module):
        x = x.reshape(-1, x.shape[-1])
        if self.autogptq_cuda_available is True and (
            self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
        ):
        ) and not self.trainable:
            out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
            if self.use_cuda_fp16:
                x = x.half()
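Both CUDA QuantLinear variants now bypass the hand-written kernel when the layer is trainable, presumably because those kernels are inference-only and not differentiable, so the forward falls through to the pure-PyTorch dequantize-and-matmul path. A small, self-contained sketch of the dispatch condition introduced above (names are illustrative):

def should_use_cuda_kernel(batch_tokens: int, cuda_available: bool,
                           kernel_switch_threshold: int, trainable: bool) -> bool:
    # Mirrors the guard added in the hunks above: take the custom CUDA kernel only for
    # inference (not trainable) and only below the kernel-switch threshold
    # (in qlinear_cuda a threshold of 0 disables the size check).
    return (
        cuda_available
        and (kernel_switch_threshold == 0 or batch_tokens < kernel_switch_threshold)
        and not trainable
    )

assert should_use_cuda_kernel(8, True, 128, trainable=False)        # CUDA kernel path
assert not should_use_cuda_kernel(8, True, 128, trainable=True)     # torch fallback path
assert not should_use_cuda_kernel(256, True, 128, trainable=False)  # above threshold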
@@ -354,7 +354,7 @@ def get_gptq_peft_model(
    train_mode: bool = False
):
    if train_mode and not model.trainable:
        model.enable_trainable_mode()
        raise TypeError("model is not trainable, please load model with 'trainable=True'")
    if train_mode and not peft_config:
        raise ValueError("peft_config not specified when in train mode.")
    if not train_mode and not model_id:
@@ -87,30 +87,29 @@ def load_data(data_path, tokenizer, n_samples, max_new_tokens):
        outputs = examples["output"]

        prompts = []
        texts = []
        outs = []
        input_ids = []
        attention_mask = []
        for istr, inp, opt in zip(instructions, inputs, outputs):
            if inp:
                prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
                text = prompt + opt
            else:
                prompt = f"Instruction:\n{istr}\nOutput:\n"
                text = prompt + opt
            if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length - max_new_tokens:
                continue

            tokenized_data = tokenizer(text)
            tokenized_data = tokenizer(prompt)

            input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
            attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
            prompts.append(prompt)
            texts.append(text)
            outs.append(opt)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt": prompts
            "prompt": prompts,
            "output": outs
        }

    dataset = Dataset.from_generator(dummy_gen)
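After this change only the prompt is tokenized and the reference output is returned alongside it, which fits a generation benchmark that compares model output against the reference. Purely for illustration, how one instruction/input/output record maps to a prompt in the loop above (the record values are made up):

record = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}
prompt = (
    f"Instruction:\n{record['instruction']}\nInput:\n{record['input']}\nOutput:\n"
    if record["input"]
    else f"Instruction:\n{record['instruction']}\nOutput:\n"
)
print(prompt)                     # this is what gets tokenized now
print(prompt + record["output"])  # the full text, which was tokenized before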
@@ -236,9 +235,9 @@ def main():
    parser.add_argument("--use_triton", action="store_true")
    parser.add_argument("--use_safetensors", action="store_true")
    parser.add_argument("--use_fast_tokenizer", action="store_true")
    parser.add_argument("--inject_fused_attention", action="store_true")
    parser.add_argument("--inject_fused_mlp", action="store_true")
    parser.add_argument("--disable_exllama", action="store_true")
    parser.add_argument("--no_inject_fused_attention", action="store_true")
    parser.add_argument("--no_inject_fused_mlp", action="store_true")
    parser.add_argument("--num_samples", type=int, default=10)
    parser.add_argument("--per_gpu_max_memory", type=int, default=None)
    parser.add_argument("--cpu_max_memory", type=int, default=None)
@@ -277,8 +276,8 @@ def main():
        use_triton=args.use_triton,
        use_safetensors=args.use_safetensors,
        use_fast_tokenizer=args.use_fast_tokenizer,
        inject_fused_attention=not args.no_inject_fused_attention,
        inject_fused_mlp=not args.no_inject_fused_mlp,
        inject_fused_attention=args.inject_fused_attention,
        inject_fused_mlp=args.inject_fused_mlp,
        disable_exllama=args.disable_exllama
    )
    end = time.time()
@@ -289,7 +288,7 @@ def main():

    if args.use_triton:
        logger.info("warmup triton, this may take a while.")
        model.warmup_triton()
        model.warmup_triton(model)

    logger.info("loading data")
    examples = load_data(
@@ -37,12 +37,18 @@ if __name__ == "__main__":
    parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
    parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
    parser.add_argument("--inject_fused_attention", action="store_true", help="Whether to inject fused attention")
    parser.add_argument("--inject_fused_mlp", action="store_true", help="Whether to inject fused mlp")
    parser.add_argument("--disable_exllama", action="store_true", help="Whether to disable the exllama kernel")
    args = parser.parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        use_fast=args.use_fast_tokenizer,
        trust_remote_code=args.trust_remote_code
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -68,8 +74,8 @@ if __name__ == "__main__":
            model_basename=args.model_basename,
            use_safetensors=args.use_safetensors,
            trust_remote_code=args.trust_remote_code,
            inject_fused_mlp=False,
            inject_fused_attention=False,
            inject_fused_mlp=args.inject_fused_mlp,
            inject_fused_attention=args.inject_fused_attention,
            disable_exllama=args.disable_exllama
        )
    else:
setup.py
@@ -68,7 +68,10 @@ requirements = [
    "datasets",
    "numpy",
    "rouge",
    "torch>=1.13.0",
    "torch>=2.0.1",
    "functorch",
    "xformers>=0.0.20",
    "vllm",
    "safetensors",
    "transformers>=4.31.0",
    "peft"