Compare commits: main...xformers_i (36 commits)

Commit SHAs:
d95661b250, 0a04d3fb2a, 7c2ec905a6, b1c64d9269, eeee7b344f, 8fedbbf82d,
fdb8c4500a, 43b9a5cd0a, efe47aafe5, 3d09cf36d7, beab695c5b, edc5b72da4,
26dc6852fe, d73ed1cfc2, e5f874e5af, 700406e6b6, 2092a80b81, 4aea0aef39,
1f9717af7f, 7a70bcf6d8, 57c3e5b7d5, 01ce32553e, 677409e2fe, 9155ef3038,
df24da5797, ab6faa6496, f67b512cee, bacac399d3, c71f5cdf12, 0fcfddda90,
2826729e73, 801610367d, 7d0909160c, 8b19122775, cd8a674002, 116d8267d7

18 changed files with 1197 additions and 145 deletions
@@ -17,11 +17,11 @@ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
from transformers.modeling_utils import no_init_weights
+from xformers.ops.fmha import AttentionOp

from ._const import *
from ._utils import *
from ..nn_modules.qlinear import GeneralQuantLinear
-from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import (

@@ -73,7 +73,7 @@ class BaseQuantizeConfig(PushToHubMixin):
        quantize_config_filename = "quantize_config.json"
        if os.path.isdir(save_dir):  # Local
            resolved_config_file = join(save_dir, quantize_config_filename)
        else:  # Remote
            resolved_config_file = cached_file(
                save_dir,
                quantize_config_filename,

@@ -89,10 +89,9 @@ class BaseQuantizeConfig(PushToHubMixin):
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
            )

        with open(resolved_config_file, "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    def to_dict(self):
        return {
            "bits": self.bits,

@@ -114,9 +113,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
    inside_layer_modules: List[List[str]] = None
    lm_head_name: str = "lm_head"

-    fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
-    fused_mlp_module_type: Optional[FusedBaseMLPModule] = None
-
    def __init__(
        self,
        model: PreTrainedModel,

@@ -210,14 +206,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is quantized.")
        if use_triton and not TRITON_AVAILABLE:
-            logger.warning("triton is not installed, reset use_triton to False")
+            logger.warning("Triton is not installed, reset use_triton to False")
            use_triton = False

        device_map = self.hf_device_map
        if device_map:
            for name, device in device_map.items():
                if device == "cpu":
-                    logger.info(f"truly offloading {name} to cpu with hook.")
+                    logger.info(f"Truly offloading {name} to cpu with hook.")
                    module = get_module_by_name_suffix(self.model, name)
                    remove_hook_from_module(module, recurse=True)
                    accelerate.cpu_offload_with_hook(module, CUDA_0)

@@ -437,10 +433,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

-    def generate(self, **kwargs):
+    def generate(self, *args, **kwargs):
        """shortcut for model.generate"""
-        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
-            return self.model.generate(**kwargs)
+        with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
+            return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""

@@ -489,9 +485,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
-        if (self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)) and save_dir is None:
-            raise ValueError("Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading.")
+        if (
+            self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)
+        ) and save_dir is None:
+            raise ValueError(
+                "Quantized model should be saved first, or you can provide save_dir to "
+                "make sure model is saved to local disk before uploading."
+            )

        if save_dir is not None:
            logger.info(f"Saving model to {save_dir}")
            self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

@@ -517,16 +518,30 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            repo_type="model",
        )

-    def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+    def save_quantized(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None
+    ):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)

        if not self.quantized:
-            raise EnvironmentError("can only save quantized model, please execute .quantize first.")
+            raise TypeError("Can only save quantized model, please execute .quantize() method first.")
+        if self.injected_fused_attention or self.injected_fused_mlp:
+            raise TypeError(
+                "At least one of attention modules and mlp modules are injected with fused ops, "
+                "please disable 'inject_fused_attention' and 'inject_fused_mlp' at model loading stage, "
+                "and don't call ._fuse_attention() and ._fuse_mlp() methods before calling this method."
+            )

        self.model.to(CPU)

-        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        model_base_name = (
+            self.quantize_config.model_file_base_name or
+            f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        )
        if use_safetensors:
            model_save_name = model_base_name + ".safetensors"
            state_dict = self.model.state_dict()

@@ -546,13 +561,23 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                    new_key = str(key)
                    new_value = str(value)
                except Exception as e:
-                    raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+                    raise TypeError(
+                        f"safetensors_metadata: both keys and values must be strings and "
+                        f"an error occured when trying to convert them: {e}"
+                    )
                if new_key in new_safetensors_metadata:
-                    logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+                    logger.warning(
+                        f"After converting safetensors_metadata keys to strings, the key "
+                        f"'{new_key}' is duplicated. Ensure that all your metadata keys are "
+                        f"strings to avoid overwriting."
+                    )
                new_safetensors_metadata[new_key] = new_value
            safetensors_metadata = new_safetensors_metadata
            if converted_keys:
-                logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+                logger.debug(
+                    f"One or more safetensors_metadata keys or values had to be converted to str(). "
+                    f"Final safetensors_metadata: {safetensors_metadata}"
+                )

            # Format is required to enable Accelerate to load the metadata
            # otherwise it raises an OSError

@@ -576,9 +601,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
        self.quantize_config.model_name_or_path = save_dir
        self.quantize_config.model_file_base_name = model_base_name

-    def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
+    def save_pretrained(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None,
+        **kwargs
+    ):
        """alias of save_quantized"""
-        logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
+        logger.warning("You are using save_pretrained, which will re-direct to save_quantized.")
        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

    @classmethod

@@ -672,9 +703,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                model.seqlen = model_config[key]
                break
        else:
-            logger.warning("can't get model's sequence length from model config, will set to 4096.")
+            logger.warning("Can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
-        model.eval()

        return cls(model, False, quantize_config)

@@ -688,8 +718,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        torch_dtype: torch.dtype = torch.float16,
-        inject_fused_attention: bool = True,
-        inject_fused_mlp: bool = True,
+        inject_fused_attention: bool = False,
+        inject_fused_mlp: bool = False,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,

@@ -697,11 +727,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
+        attn_op: Optional[AttentionOp] = None,
        disable_exllama: bool = False,
        **kwargs
    ):
        """load quantized model from local disk"""

        # Parameters related to loading from Hugging Face Hub
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

@@ -725,9 +756,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            "_raise_exceptions_for_missing_entries": False,
            "_commit_hash": commit_hash,
        }

        if use_triton and not TRITON_AVAILABLE:
-            logger.warning("Triton is not installed, reset use_triton to False.")
+            logger.warning("Triton is not installed, reset use_triton to False")
            use_triton = False
        if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
            logger.warning(

@@ -746,22 +777,33 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                "2. You are using pytorch without CUDA support.\n"
                "3. CUDA and nvcc are not installed in your device."
            )
+        if any([inject_fused_attention, inject_fused_mlp]) and trainable:
+            logger.warning(
+                "Neither fused attention nor fused mlp is tested under trainable mode, "
+                "this may cause unexpected behavior or lead to error if you are training "
+                "a quantized model with fused ops, please consider disabling 'inject_fused_attention' "
+                "and 'inject_fused_mlp'."
+            )

        # == step1: prepare configs and file names == #
-        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
+        config = AutoConfig.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **cached_file_kwargs
+        )

        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        if quantize_config is None:
            quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)

        if model_basename is None:
            if quantize_config.model_file_base_name:
                model_basename = quantize_config.model_file_base_name
            else:
                model_basename = f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g"

        quantize_config.model_name_or_path = model_name_or_path
        quantize_config.model_file_base_name = model_basename

@@ -786,18 +828,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
                if resolved_archive_file is not None:
                    break

        if resolved_archive_file is None:  # Could not find a model file to use
            raise FileNotFoundError(f"Could not find model in {model_name_or_path}")

        model_save_name = resolved_archive_file

        if not disable_exllama and trainable:
            logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
            disable_exllama = True

        elif not use_triton and trainable:
-            logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
+            logger.warning(
+                "QuantLinear with cuda backend not support trainable mode yet, will switch to pytorch backend, "
+                "this may cause very slow inference speed, disable trainable if you are not training model."
+            )

        # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
        def skip(*args, **kwargs):

@@ -881,7 +925,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            )
        model = simple_dispatch_model(model, device_map)

-        # == step4: set seqlen == #
+        # == step4: post init model == #
+        # Any post-initialization that require device information, for example buffers initialization on device.
+        model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
+
+        # == step5: set seqlen == #
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):

@@ -890,50 +938,62 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                model.seqlen = model_config[key]
                break
        else:
-            logger.warning("can't get model's sequence length from model config, will set to 4096.")
+            logger.warning("Can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096

-        # == step5: (optional) inject optimized module == #
+        # == step6: (optional) inject optimized module == #
        if inject_fused_attention:
-            if cls.fused_attn_module_type is None:
+            try:
+                cls._fuse_attention(model, attn_op, trainable)
+            except NotImplementedError:
                inject_fused_attention = False
-                logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
-            else:
-                cls.fused_attn_module_type.inject_to_model(
-                    model,
-                    use_triton=use_triton,
-                    group_size=quantize_config.group_size,
-                    use_cuda_fp16=use_cuda_fp16,
-                    desc_act=quantize_config.desc_act,
-                    trainable=trainable,
-                    bits=quantize_config.bits,
-                    disable_exllama=disable_exllama,
+                logger.warning(
+                    f"{cls.__name__} doesn't support fusing attention yet, will skip inject fused attention."
                )
+            except:
+                logger.error(
+                    f"Inject fused attention failed, you can set 'inject_fused_attention' to False to "
+                    f"bypass the error for now and report it on github."
+                )
+                raise
        if inject_fused_mlp:
-            if cls.fused_mlp_module_type is None:
+            try:
+                cls._fuse_mlp(model, trainable)
+            except NotImplementedError:
                inject_fused_mlp = False
-                logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
-            else:
-                cls.fused_mlp_module_type.inject_to_model(
-                    model,
-                    use_triton=use_triton
+                logger.warning(
+                    f"{cls.__name__} doesn't support fusing mlp yet, will skip inject fused mlp."
                )
+            except:
+                logger.error(
+                    f"Inject fused mlp failed, you can set 'inject_fused_mlp' to False to "
+                    f"bypass the error for now and report it on github."
+                )
+                raise
+        if inject_fused_attention or inject_fused_mlp:
+            logger.warning(
+                "You are using at least one of 'inject_fused_attention' and 'inject_fused_mlp' "
+                "modes, which are now marked as experimental features, feel free to open an issue "
+                "or ask any question about those two features on github if you encounter unexpected "
+                "behaviors and errors."
+            )

-        # Any post-initialization that require device information, for example buffers initialization on device.
-        model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
-
-        model.eval()
-
-        # == step6: (optional) warmup triton == #
+        # == step7: (optional) warmup triton == #
        if use_triton and warmup_triton:
-            from ..nn_modules.qlinear.qlinear_triton import QuantLinear
-            QuantLinear.warmup(model, seqlen=model.seqlen)
-
-            if inject_fused_mlp and cls.fused_mlp_module_type is not None:
-                cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
+            cls.warmup_triton(model)

-        # == step7: make model compatible with peft
-        cls.make_sure_compatible_with_peft(
-            model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
+        # == step8: convert all QuantLinear to sub-class of torch.nn.Linear
+        # note if _fuse_attention() and _fuse_mlp() is implemented,
+        # all QuantLinear will be converted to sub-class of torch.nn.Linear at injection stage
+        GeneralQuantLinear.convert_to_torch_linear(
+            model,
+            dynamically_import_QuantLinear(
+                use_triton,
+                quantize_config.desc_act,
+                quantize_config.group_size,
+                quantize_config.bits,
+                disable_exllama
+            )
        )

        return cls(

@@ -942,39 +1002,35 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            quantize_config,
            is_triton_backend=use_triton,
            injected_fused_attention=inject_fused_attention,
-            injected_fused_mlp=inject_fused_mlp and use_triton,
+            injected_fused_mlp=inject_fused_mlp,
            trainable=trainable
        )

-    def warmup_triton(self, enabled: bool = True):
+    @staticmethod
+    def _fuse_attention(
+        model: PreTrainedModel,
+        attn_op: Optional[AttentionOp] = None,
+        trainable: bool = False
+    ) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def _fuse_mlp(
+        model: PreTrainedModel,
+        trainable: bool = False
+    ) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def warmup_triton(model: nn.Module, enabled: bool = True) -> None:
        if not enabled:
            return
        if not TRITON_AVAILABLE:
-            logger.warning(f"triton is not available, skip warmup stage directly.")
+            logger.warning(f"Triton is not available, skip warmup stage directly.")
            return

        from ..nn_modules.qlinear.qlinear_triton import QuantLinear
-        QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
-
-        if self.fused_mlp_module_type is not None:
-            self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
-
-    def enable_trainable_mode(self, enabled: bool = True):
-        if not self.is_triton_backend and enabled:
-            raise NotImplementedError("For now, trainable mode only supports triton backend.")
-        for n, m in self.model.named_modules():
-            if hasattr(m, "trainable"):
-                setattr(m, "trainable", enabled)
-
-    def disable_trainable_mode(self):
-        self.enable_trainable_mode(enabled=False)
-
-    @staticmethod
-    def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
-        GeneralQuantLinear.inject_to_model(
-            model,
-            dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
-        )
+        QuantLinear.warmup(model, seqlen=model.seqlen)

    def __getattr__(self, item):
        try:
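The refactor above replaces the class-level fused_attn_module_type / fused_mlp_module_type attributes with overridable _fuse_attention() and _fuse_mlp() static methods that raise NotImplementedError by default. The following standalone sketch (hypothetical FakeBase / FakeLlama / load names, not AutoGPTQ code) mimics how from_quantized() now probes those hooks and skips fusion for models that do not implement them:

# Standalone sketch of the dispatch pattern introduced in this PR.
class FakeBase:
    @staticmethod
    def _fuse_attention(model, attn_op=None, trainable=False) -> None:
        raise NotImplementedError()

    @staticmethod
    def _fuse_mlp(model, trainable=False) -> None:
        raise NotImplementedError()


class FakeLlama(FakeBase):
    @staticmethod
    def _fuse_attention(model, attn_op=None, trainable=False) -> None:
        # a real implementation would swap the model's attention layers for fused modules
        print("fused attention injected")


def load(cls, model, inject_fused_attention=True, inject_fused_mlp=True):
    if inject_fused_attention:
        try:
            cls._fuse_attention(model)
        except NotImplementedError:
            inject_fused_attention = False  # skipped, as from_quantized() does (with a warning)
    if inject_fused_mlp:
        try:
            cls._fuse_mlp(model)
        except NotImplementedError:
            inject_fused_mlp = False
    return inject_fused_attention, inject_fused_mlp


print(load(FakeLlama, model=None))  # (True, False): attention fused, mlp skipped
print(load(FakeBase, model=None))   # (False, False)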
@@ -1,9 +1,5 @@
-from packaging.version import parse as parse_version
-
from torch import device

-from ..utils.import_utils import compare_transformers_version
-
CPU = device("cpu")
CUDA_0 = device("cuda:0")

@@ -20,9 +16,8 @@ SUPPORTED_MODELS = [
    "RefinedWeb",
    "baichuan",
    "internlm",
+    "llama",
    "qwen",
]
-if compare_transformers_version("v4.28.0", op="ge"):
-    SUPPORTED_MODELS.append("llama")

__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]
@@ -1,6 +1,8 @@
from inspect import signature
from typing import Dict, Optional, Union

+from xformers.ops.fmha import AttentionOp
+
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM

@@ -81,6 +83,7 @@ class AutoGPTQForCausalLM:
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
+        attn_op: Optional[AttentionOp] = None,
        disable_exllama: bool = False,
        **kwargs
    ) -> BaseGPTQForCausalLM:

@@ -121,6 +124,7 @@ class AutoGPTQForCausalLM:
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            trainable=trainable,
+            attn_op=attn_op,
            disable_exllama=disable_exllama,
            **keywords
        )
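Putting the new attn_op plumbing together, a loading call on this branch would presumably look like the sketch below. The repository id is a placeholder and defaults may change, but the parameter names match the from_quantized() signature above, and the (forward_op, None) pair mirrors the default operator pair used in fused_modules/attention.py:

# Hypothetical usage sketch, not taken verbatim from the PR.
from auto_gptq import AutoGPTQForCausalLM
from xformers.ops import fmha

model = AutoGPTQForCausalLM.from_quantized(
    "some-user/some-gptq-model",        # placeholder: any local dir or hub id with a quantized model
    use_safetensors=True,
    inject_fused_attention=True,        # note: the default is now False on this branch
    inject_fused_mlp=False,
    attn_op=(fmha.flash.FwOp(), None),  # xformers (forward, backward) op pair
)
# generate() now forwards *args and **kwargs straight to model.generate();
# input_ids would come from any Hugging Face tokenizer.
# output_ids = model.generate(input_ids=input_ids, max_new_tokens=32)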
@@ -1,4 +1,19 @@
+from copy import deepcopy
+from typing import Optional
+
+import xformers.ops as xop
+from torch.cuda import empty_cache
+from transformers import PreTrainedModel
+from xformers.ops.fmha import AttentionOp
+
from ._base import *
+from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
+from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
+from ..nn_modules.fused_modules.mlp import FusedGatedMLP
+
+
+class BaiChuanFusedAttentionWithRope(FusedAttentionWithRoPE):
+    pass
+
+
class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -12,5 +27,49 @@ class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.down_proj"]
    ]

+    @staticmethod
+    def _fuse_attention(
+        model: PreTrainedModel,
+        attn_op: Optional[AttentionOp] = None,
+        trainable: bool = False
+    ) -> None:
+        model_config = model.config
+        num_heads = model_config.num_attention_heads
+        scale = (model_config.hidden_size // num_heads) ** -0.5
+        layers = model.model.layers
+
+        rope_cache = build_rope_cache(
+            rotary_dim=model_config.hidden_size // num_heads,
+            max_position=model_config.max_position_embeddings,
+            device=model.device,
+            dtype=model.dtype
+        )
+
+        for layer in layers:
+            old_attn = layer.self_attn
+            attn_device = old_attn.W_pack.qweight.data.device
+            new_qkv_proj = FusedGeneralQuantLinear(old_attn.W_pack)
+            new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
+            new_attn = BaiChuanFusedAttentionWithRope(
+                qkv_proj=new_qkv_proj,
+                out_proj=new_out_proj,
+                cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
+                num_query_heads=num_heads,
+                num_key_heads=num_heads,
+                num_value_heads=num_heads,
+                attn_dropout=0.0,
+                resid_dropout=0.0,
+                scale=scale,
+                attention_ops=attn_op,
+                outputs_handler=(lambda x, y, z: (x, z, y)),
+                training=trainable
+            )
+
+            layer.self_attn = new_attn
+
+            del old_attn
+
+            empty_cache()
+
+
__all__ = ["BaiChuanGPTQForCausalLM"]
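The outputs_handler=(lambda x, y, z: (x, z, y)) argument above is just a tuple reordering: FusedAttention.forward() builds (hidden_states, present_key_value, attn_weights) and the handler rearranges that into the layout the original Hugging Face attention module is expected to return. A minimal, runnable illustration:

# Tiny demo of the reordering performed by the outputs_handler lambda.
handler = lambda x, y, z: (x, z, y)

out, present, attn_weights = "hidden_states", ("k_cache", "v_cache"), None
print(handler(out, present, attn_weights))
# ('hidden_states', None, ('k_cache', 'v_cache'))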
@@ -1,4 +1,4 @@
-from auto_gptq.modeling import BaseGPTQForCausalLM
+from ._base import BaseGPTQForCausalLM


class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -14,4 +14,5 @@ class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.c_proj"]
    ]

+
__all__ = ["GPTBigCodeGPTQForCausalLM"]
@@ -1,5 +1,136 @@
+from copy import deepcopy
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import xformers.ops as xop
+from torch.cuda import empty_cache
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
+from xformers.ops.fmha import AttentionOp
+
from ._base import *
-from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
+from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
+from ..nn_modules.fused_modules.attention import FusedAttention
+from ..nn_modules.fused_modules.mlp import FusedMLP
+
+
+class GPTJFusedAttention(FusedAttention):
+    def __init__(
+        self,
+        qkv_proj: nn.Linear,
+        out_proj: nn.Linear,
+        embed_positions: torch.Tensor,
+        rotary_dim: Optional[int],
+        num_query_heads: int,
+        num_key_heads: int,
+        num_value_heads: int,
+        attn_dropout: float = 0.0,
+        resid_dropout: float = 0.0,
+        scale: Optional[float] = None,
+        attention_ops: Optional[xop.AttentionOp] = None,
+        outputs_handler: Optional[Callable] = None,
+        training: bool = False,
+    ):
+        super(GPTJFusedAttention, self).__init__(
+            qkv_proj,
+            out_proj,
+            num_query_heads,
+            num_key_heads,
+            num_value_heads,
+            attn_dropout,
+            resid_dropout,
+            scale,
+            attention_ops,
+            outputs_handler,
+            training
+        )
+        self.embed_positions = embed_positions
+        self.rotary_dim = rotary_dim
+
+    def _get_embed_positions(self, position_ids: torch.Tensor):
+        return self.embed_positions.repeat(position_ids.shape[0], 1, 1)
+
+    def _apply_rotary(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        bsz, seq_len = key.shape[:2]
+
+        dtype = query.dtype
+        query = query.view(bsz, seq_len, self.num_query_heads, -1).to(dtype=torch.float)
+        key = key.view(bsz, seq_len, self.num_key_heads, -1).to(dtype=torch.float)
+
+        embed_positions = self._get_embed_positions(position_ids)
+
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+        if self.rotary_dim is not None:
+            k_rot = key[:, :, :, : self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim:]
+
+            q_rot = query[:, :, :, : self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim:]
+
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
+        else:
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)
+
+        return query.view(bsz, seq_len, -1).to(dtype=dtype), key.view(bsz, seq_len, -1).to(dtype=dtype)
+
+    def _build_attn_bias(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None
+    ) -> Optional[xop.AttentionBias]:
+        return xop.LowerTriangularMask()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        **kwargs
+    ):
+        bsz, seq_len = hidden_states.shape[:2]
+
+        q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
+
+        if position_ids is not None:
+            q, k = self._apply_rotary(q, k, position_ids)
+
+        attn_bias = self._build_attn_bias(hidden_states, attention_mask) if layer_past is None else None
+        attn_out, present = self._attn(
+            bsz,
+            seq_len,
+            q,
+            k,
+            v,
+            attn_bias,
+            use_cache,
+            layer_past
+        )
+
+        out = self.out_proj(attn_out)
+        out = self.resid_dropout(out)
+
+        outputs = (out, present, None)
+        if self.outputs_handler:
+            outputs = self.outputs_handler(*outputs)
+
+        return outputs
+
+
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -13,7 +144,75 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.fc_out"]
    ]

-    fused_attn_module_type = FusedGPTJAttentionForQuantizedModel
+    @staticmethod
+    def _fuse_attention(
+        model: PreTrainedModel,
+        attn_op: Optional[AttentionOp] = None,
+        trainable: bool = False
+    ) -> None:
+        model_config = model.config
+        num_heads = model_config.n_head
+        scale = (model_config.hidden_size // num_heads) ** -0.5
+        layers = model.transformer.h
+
+        for layer in layers:
+            old_attn = layer.attn
+            device = old_attn.q_proj.qweight.data.device
+            new_qkv_proj = FusedGeneralQuantLinear.fuse(
+                old_attn.q_proj,
+                old_attn.k_proj,
+                old_attn.v_proj
+            )
+            new_out_proj = FusedGeneralQuantLinear(old_attn.out_proj)
+            new_attn = GPTJFusedAttention(
+                qkv_proj=new_qkv_proj,
+                out_proj=new_out_proj,
+                embed_positions=old_attn.embed_positions.to(device),
+                rotary_dim=old_attn.rotary_dim,
+                num_query_heads=num_heads,
+                num_key_heads=num_heads,
+                num_value_heads=num_heads,
+                attn_dropout=model_config.attn_pdrop,
+                resid_dropout=model_config.resid_pdrop,
+                scale=scale,
+                attention_ops=attn_op,
+                outputs_handler=None,
+                training=trainable
+            )
+
+            layer.attn = new_attn
+
+            del old_attn
+
+            empty_cache()
+
+    @staticmethod
+    def _fuse_mlp(
+        model: PreTrainedModel,
+        trainable: bool = False
+    ) -> None:
+        model_config = model.config
+        act = ACT2FN[model_config.activation_function]
+        out_dropout = model_config.resid_pdrop
+        layers = model.transformer.h
+
+        for layer in layers:
+            old_mlp = layer.mlp
+            new_mlp = FusedMLP(
+                input_proj=FusedGeneralQuantLinear(old_mlp.fc_in),
+                out_proj=FusedGeneralQuantLinear(old_mlp.fc_out),
+                activation=act,
+                in_dropout=0.0,
+                out_dropout=out_dropout,
+                training=trainable,
+                residual=False
+            )
+
+            layer.mlp = new_mlp
+
+            del old_mlp
+
+            empty_cache()
+
+
__all__ = ["GPTJGPTQForCausalLM"]
@@ -1,16 +1,18 @@
-from logging import getLogger
+from copy import deepcopy
+from typing import Optional
+
+from torch.cuda import empty_cache
+from transformers import PreTrainedModel
+from xformers.ops.fmha import AttentionOp

from ._base import *
-from ..utils.import_utils import compare_transformers_version
+from ..nn_modules.fused_modules.attention import build_rope_cache, FusedAttentionWithRoPE
+from ..nn_modules.fused_modules.linear import FusedGeneralQuantLinear
+from ..nn_modules.fused_modules.mlp import FusedGatedMLP

-if compare_transformers_version("v4.28.0", op="ge"):
-    from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
-    from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
-else:
-    FusedLlamaAttentionForQuantizedModel = None
-    FusedLlamaMLPForQuantizedModel = None

-logger = getLogger(__name__)
+class LlamaFusedAttentionWithRoPE(FusedAttentionWithRoPE):
+    pass


class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -24,8 +26,61 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.down_proj"]
    ]

-    fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
-    fused_mlp_module_type = FusedLlamaMLPForQuantizedModel
+    @staticmethod
+    def _fuse_attention(
+        model: PreTrainedModel,
+        attn_op: Optional[AttentionOp] = None,
+        trainable: bool = False
+    ) -> None:
+        model_config = model.config
+        num_heads = model_config.num_attention_heads
+        scale = (model_config.hidden_size // num_heads) ** -0.5
+        layers = model.model.layers
+
+        rope_cache = build_rope_cache(
+            rotary_dim=model_config.hidden_size // num_heads,
+            max_position=model_config.max_position_embeddings,
+            base=10000,
+            device=model.device,
+            dtype=model.dtype
+        )
+
+        for layer in layers:
+            old_attn = layer.self_attn
+            attn_device = old_attn.q_proj.qweight.data.device
+            new_qkv_proj = FusedGeneralQuantLinear.fuse(
+                old_attn.q_proj,
+                old_attn.k_proj,
+                old_attn.v_proj
+            )
+            new_out_proj = FusedGeneralQuantLinear(old_attn.o_proj)
+            new_attn = LlamaFusedAttentionWithRoPE(
+                qkv_proj=new_qkv_proj,
+                out_proj=new_out_proj,
+                cos_sin_cache=rope_cache if attn_device == model.device else deepcopy(rope_cache).to(attn_device),
+                num_query_heads=num_heads,
+                num_key_heads=num_heads,
+                num_value_heads=num_heads,
+                attn_dropout=0.0,
+                resid_dropout=0.0,
+                scale=scale,
+                attention_ops=attn_op,
+                outputs_handler=(lambda x, y, z: (x, z, y)),
+                training=trainable
+            )
+
+            layer.self_attn = new_attn
+
+            del old_attn
+
+            empty_cache()
+
+    # @staticmethod
+    # def _fuse_mlp(
+    #     model: PreTrainedModel,
+    #     trainable: bool = False
+    # ) -> None:
+    #     pass
+
+
__all__ = ["LlamaGPTQForCausalLM"]
0
auto_gptq/nn_modules/fused_modules/__init__.py
Normal file
0
auto_gptq/nn_modules/fused_modules/__init__.py
Normal file
394
auto_gptq/nn_modules/fused_modules/attention.py
Normal file
394
auto_gptq/nn_modules/fused_modules/attention.py
Normal file
|
@ -0,0 +1,394 @@
|
||||||
|
import math
|
||||||
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import xformers.ops as xop
|
||||||
|
from vllm import pos_encoding_ops as vllm_pos_encoding_ops
|
||||||
|
from xformers.ops.fmha.attn_bias import LowerTriangularMask, LowerTriangularMaskWithTensorBias
|
||||||
|
|
||||||
|
|
||||||
|
POTENTIAL_KV_CACHE_NAMES = (
|
||||||
|
"past_key_value",
|
||||||
|
"layer_past",
|
||||||
|
"kv_cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_to_get_kv_cache(**kwargs) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
kv_cache = None
|
||||||
|
for name in POTENTIAL_KV_CACHE_NAMES:
|
||||||
|
if name in kwargs:
|
||||||
|
return kwargs[name]
|
||||||
|
return kv_cache
|
||||||
|
|
||||||
|
|
||||||
|
def build_rope_cache(
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position: int = 2048,
|
||||||
|
base: int = 10000,
|
||||||
|
device: torch.device = torch.device("cuda:0"),
|
||||||
|
dtype: torch.dtype = torch.float16
|
||||||
|
): # TODO: support multiple scaling strategies
|
||||||
|
inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim)))
|
||||||
|
t = torch.arange(max_position, device=device, dtype=dtype)
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
def build_alibi_slopes(
|
||||||
|
num_heads: int,
|
||||||
|
device: torch.device = torch.device("cuda:0"),
|
||||||
|
dtype: torch.dtype = torch.float16
|
||||||
|
):
|
||||||
|
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||||
|
base = torch.tensor(
|
||||||
|
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
||||||
|
slopes = torch.pow(base, powers)
|
||||||
|
|
||||||
|
if closest_power_of_2 != num_heads:
|
||||||
|
extra_base = torch.tensor(
|
||||||
|
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||||
|
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
||||||
|
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||||
|
|
||||||
|
slopes = slopes.to(dtype)
|
||||||
|
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_ops: Optional[xop.AttentionOp] = (xop.fmha.flash.FwOp(), None),
|
||||||
|
attention_bias: Optional[xop.AttentionBias] = None,
|
||||||
|
p: float = 0.0,
|
||||||
|
scale: Optional[float] = None
|
||||||
|
):
|
||||||
|
if value.shape[2] != query.shape[2]:
|
||||||
|
# MQA expand
|
||||||
|
if value.shape[2] == 1:
|
||||||
|
pass # TODO
|
||||||
|
# GQA reshape
|
||||||
|
else:
|
||||||
|
original_shape = value.shape
|
||||||
|
pass # TODO
|
||||||
|
|
||||||
|
return xop.memory_efficient_attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
attn_bias=attention_bias,
|
||||||
|
p=p,
|
||||||
|
scale=scale,
|
||||||
|
op=attention_ops
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qkv_proj: nn.Linear,
|
||||||
|
out_proj: nn.Linear,
|
||||||
|
num_query_heads: int,
|
||||||
|
num_key_heads: int,
|
||||||
|
num_value_heads: int,
|
||||||
|
attn_dropout: float = 0.0,
|
||||||
|
resid_dropout: float = 0.0,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
attention_ops: Optional[xop.AttentionOp] = None,
|
||||||
|
outputs_handler: Optional[Callable] = None,
|
||||||
|
training: bool = False,
|
||||||
|
):
|
||||||
|
super(FusedAttention, self).__init__()
|
||||||
|
|
||||||
|
self.qkv_proj = qkv_proj
|
||||||
|
self.out_proj = out_proj
|
||||||
|
|
||||||
|
self.num_query_heads = num_query_heads
|
||||||
|
self.num_key_heads = num_key_heads
|
||||||
|
self.num_value_heads = num_value_heads
|
||||||
|
|
||||||
|
self.attn_dropout = attn_dropout if training else 0.0
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.attention_ops = attention_ops
|
||||||
|
|
||||||
|
self.outputs_handler = outputs_handler
|
||||||
|
|
||||||
|
self.resid_dropout = nn.Dropout(resid_dropout if training else 0.0)
|
||||||
|
|
||||||
|
def _build_attn_bias(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None
|
||||||
|
) -> Optional[xop.AttentionBias]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _attn(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
attention_bias: Optional[xop.AttentionBias] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||||
|
):
|
||||||
|
q = q.view(batch_size, seq_len, self.num_query_heads, -1).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, seq_len, self.num_key_heads, -1).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, seq_len, self.num_value_heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
k = torch.cat((k_cache, k), dim=2)
|
||||||
|
v = torch.cat((v_cache, v), dim=2)
|
||||||
|
|
||||||
|
present = None
|
||||||
|
if use_cache:
|
||||||
|
present = (k, v)
|
||||||
|
|
||||||
|
attn_out = attention(
|
||||||
|
query=q.transpose(1, 2),
|
||||||
|
key=k.transpose(1, 2),
|
||||||
|
value=v.transpose(1, 2),
|
||||||
|
attention_ops=self.attention_ops,
|
||||||
|
attention_bias=attention_bias,
|
||||||
|
p=self.attn_dropout,
|
||||||
|
scale=self.scale
|
||||||
|
).view(batch_size, seq_len, -1)
|
||||||
|
|
||||||
|
return attn_out, present
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
bsz, seq_len = hidden_states.shape[:2]
|
||||||
|
use_cache = kwargs.get("use_cache", False)
|
||||||
|
kv_cache = _try_to_get_kv_cache(**kwargs)
|
||||||
|
|
||||||
|
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
|
attn_bias = self._build_attn_bias(hidden_states, attention_mask)
|
||||||
|
attn_out, present = self._attn(
|
||||||
|
bsz,
|
||||||
|
seq_len,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
attn_bias,
|
||||||
|
use_cache,
|
||||||
|
kv_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.out_proj(attn_out)
|
||||||
|
out = self.resid_dropout(out)
|
||||||
|
|
||||||
|
outputs = (out, present, None)
|
||||||
|
if self.outputs_handler:
|
||||||
|
outputs = self.outputs_handler(*outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttentionWithRoPE(FusedAttention):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qkv_proj: nn.Linear,
|
||||||
|
out_proj: nn.Linear,
|
||||||
|
cos_sin_cache: torch.Tensor,
|
||||||
|
num_query_heads: int,
|
||||||
|
num_key_heads: int,
|
||||||
|
num_value_heads: int,
|
||||||
|
attn_dropout: float = 0.0,
|
||||||
|
resid_dropout: float = 0.0,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
attention_ops: Optional[xop.AttentionOp] = None,
|
||||||
|
outputs_handler: Optional[Callable] = None,
|
||||||
|
training: bool = False,
|
||||||
|
):
|
||||||
|
super(FusedAttentionWithRoPE, self).__init__(
|
||||||
|
qkv_proj=qkv_proj,
|
||||||
|
out_proj=out_proj,
|
||||||
|
num_query_heads=num_query_heads,
|
||||||
|
num_key_heads=num_key_heads,
|
||||||
|
num_value_heads=num_value_heads,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
resid_dropout=resid_dropout,
|
||||||
|
scale=scale,
|
||||||
|
attention_ops=attention_ops,
|
||||||
|
outputs_handler=outputs_handler,
|
||||||
|
training=training
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)
|
||||||
|
|
||||||
|
def _build_attn_bias(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None
|
||||||
|
) -> Optional[xop.AttentionBias]:
|
||||||
|
return LowerTriangularMask()
|
||||||
|
|
||||||
|
def _apply_rotary_embedding(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.Tensor] = None
|
||||||
|
):
|
||||||
|
bsz, seq_len, hidden_size = query.shape
|
||||||
|
|
||||||
|
if position_ids is not None:
|
||||||
|
query = query.view(bsz * seq_len, -1)
|
||||||
|
key = key.view(bsz * seq_len, -1)
|
||||||
|
vllm_pos_encoding_ops.rotary_embedding_neox(
|
||||||
|
position_ids.view(-1).to(query.device),
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
hidden_size // self.num_query_heads,
|
||||||
|
self.cos_sin_cache,
|
||||||
|
)
|
||||||
|
query = query.view(bsz, seq_len, -1)
|
||||||
|
key = key.view(bsz, seq_len, -1)
|
||||||
|
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
bsz, seq_len = hidden_states.shape[:2]
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
use_cache = kwargs.get("use_cache", False)
|
||||||
|
kv_cache = _try_to_get_kv_cache(**kwargs)
|
||||||
|
|
||||||
|
q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
|
q, k = self._apply_rotary_embedding(q, k, position_ids)
|
||||||
|
|
||||||
|
attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
|
||||||
|
attn_out, present = self._attn(
|
||||||
|
bsz,
|
||||||
|
seq_len,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
attn_bias,
|
||||||
|
use_cache,
|
||||||
|
kv_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.out_proj(attn_out)
|
||||||
|
out = self.resid_dropout(out)
|
||||||
|
|
||||||
|
outputs = (out, present, None)
|
||||||
|
if self.outputs_handler:
|
||||||
|
outputs = self.outputs_handler(*outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttentionWithALiBi(FusedAttention):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qkv_proj: nn.Linear,
|
||||||
|
out_proj: nn.Linear,
|
||||||
|
alibi_slopes: torch.Tensor,
|
||||||
|
num_query_heads: int,
|
||||||
|
num_key_heads: int,
|
||||||
|
num_value_heads: int,
|
||||||
|
attn_dropout: float = 0.0,
|
||||||
|
resid_dropout: float = 0.0,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
attention_ops: Optional[xop.AttentionOp] = None,
|
||||||
|
outputs_handler: Optional[Callable] = None,
|
||||||
|
training: bool = False,
|
||||||
|
):
|
||||||
|
super(FusedAttentionWithALiBi, self).__init__(
|
||||||
|
qkv_proj=qkv_proj,
|
||||||
|
out_proj=out_proj,
|
||||||
|
num_query_heads=num_query_heads,
|
||||||
|
num_key_heads=num_key_heads,
|
||||||
|
num_value_heads=num_value_heads,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
resid_dropout=resid_dropout,
|
||||||
|
scale=scale,
|
||||||
|
attention_ops=attention_ops,
|
||||||
|
outputs_handler=outputs_handler,
|
||||||
|
training=training
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
|
||||||
|
|
||||||
|
    def _build_attn_bias(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Optional[xop.AttentionBias]:  # adapted from vllm
        bsz, seq_len = hidden_states.shape[:2]

        bias = torch.arange(seq_len)
        bias = bias[None, :] - bias[:, None]
        bias = bias.to(hidden_states.device)

        # When using custom attention bias, xformers requires the bias to
        # be sliced from a tensor whose length is a multiple of 8.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            self.num_query_heads,
            padded_len,
            padded_len,
            device=self.alibi_slopes.device,
        )[:, :seq_len, :seq_len].copy_(bias)
        bias.mul_(self.alibi_slopes[:, None, None])
        bias = LowerTriangularMaskWithTensorBias(bias.unsqueeze(0).repeat(bsz, 1, 1, 1))

        return bias

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        bsz, seq_len = hidden_states.shape[:2]
        use_cache = kwargs.get("use_cache", False)
        kv_cache = _try_to_get_kv_cache(**kwargs)

        q, k, v = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)

        attn_bias = self._build_attn_bias(hidden_states, attention_mask) if kv_cache is None else None
        attn_out, present = self._attn(
            bsz,
            seq_len,
            q,
            k,
            v,
            attn_bias,
            use_cache,
            kv_cache
        )

        out = self.out_proj(attn_out)
        out = self.resid_dropout(out)

        outputs = (out, present, None)
        if self.outputs_handler:
            outputs = self.outputs_handler(*outputs)

        return outputs
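The class above only consumes a precomputed `alibi_slopes` tensor and wraps it into an xformers `LowerTriangularMaskWithTensorBias`. For context, a hedged sketch of how per-head ALiBi slopes are conventionally derived (the geometric sequence used by BLOOM-style models); this helper is an illustration, not part of the diff:

```python
import math
import torch

def build_alibi_slopes(num_heads: int) -> torch.Tensor:
    # one slope scalar per head, forming a geometric sequence
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_pow2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        # interleave extra slopes when the head count is not a power of two
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_pow2) - 3)))
        slopes += [extra_base ** (2 * i + 1) for i in range(num_heads - closest_pow2)]
    return torch.tensor(slopes)
```

Each head's slope multiplies the (key position − query position) distance matrix built in `_build_attn_bias`, so more distant tokens are penalized more strongly before the causal mask is applied.
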
97 auto_gptq/nn_modules/fused_modules/linear.py Normal file
@@ -0,0 +1,97 @@
import torch

from ..qlinear import GeneralQuantLinear
from ..qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
from ..qlinear.qlinear_cuda_old import QuantLinear as OldCudaQuantLinear
try:
    from ..qlinear.qlinear_triton import QuantLinear as TritonQuantLinear
except:
    TritonQuantLinear = None
try:
    from ..qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
except:
    ExllamaQuantLinear = None


class FusedGeneralQuantLinear(GeneralQuantLinear):
    def __init__(self, quant_linear_module):
        super(FusedGeneralQuantLinear, self).__init__(quant_linear_module)

    @classmethod
    def fuse(
        cls,
        q_proj,
        k_proj=None,
        v_proj=None,
    ):
        qweights, qzeros, scales, g_idx, bias = [], [], [], [], []
        outfeatures = 0
        for module in [q_proj, k_proj, v_proj]:
            if module is not None:
                qweights.append(module.qweight)
                qzeros.append(module.qzeros)
                scales.append(module.scales)
                g_idx.append(module.g_idx)
                bias.append(module.bias)
                outfeatures += module.outfeatures

        if bias[0] is None:
            bias = None

        if len(qweights) > 1:
            qweights = torch.cat(qweights, dim=1)
            qzeros = torch.cat(qzeros, dim=1)
            scales = torch.cat(scales, dim=1)
            g_idx = torch.cat(g_idx, dim=0)
            if bias is not None:
                bias = torch.cat(bias, dim=0)

        qlinear_args = (
            q_proj.bits,
            q_proj.group_size,
            q_proj.infeatures,
            outfeatures,
            bias is not None
        )
        qlinear_kwargs = {"trainable": q_proj.trainable}
        if isinstance(q_proj, (OldCudaQuantLinear, CudaQuantLinear)):
            qlinear_kwargs["kernel_switch_threshold"] = q_proj.kernel_switch_threshold
            if isinstance(q_proj, OldCudaQuantLinear):
                qlinear_kwargs["use_cuda_fp16"] = q_proj.use_cuda_fp16
                QuantLinear = OldCudaQuantLinear
            else:
                QuantLinear = CudaQuantLinear
        elif isinstance(q_proj, TritonQuantLinear):
            QuantLinear = TritonQuantLinear
        else:
            QuantLinear = ExllamaQuantLinear
        fused_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)

        fused_proj.qweight = qweights
        fused_proj.qzeros = qzeros
        fused_proj.scales = scales
        fused_proj.g_idx = g_idx
        fused_proj.bias = bias

        if isinstance(q_proj, ExllamaQuantLinear):
            if not hasattr(q_proj, "_use_act_order"):
                raise AttributeError(
                    "q_proj doesn't have attribute _use_act_order, please execute "
                    "auto_gptq.modeling._utils.autogptq_post_init function before "
                    "fusing quant linears."
                )
            if q_proj._use_act_order:
                # TODO: support it. The issue may lie in the line:
                #   int groups = qzeros.size(0);
                # in exllama_ext.cpp
                raise ValueError(
                    "Exllama kernel does not support layer fusion with act-order. "
                    "Please either use inject_fused_attention=False or disable_exllama=True."
                )
            fused_proj._use_act_order = q_proj._use_act_order
            fused_proj.g_idx = None
            fused_proj.post_init()

        del q_proj, k_proj, v_proj

        return cls(fused_proj)
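A hedged usage sketch of the new `FusedGeneralQuantLinear.fuse` classmethod: it concatenates a layer's separate quantized q/k/v projections into a single QKV projection. The `layer.self_attn.{q,k,v}_proj` layout is an assumed LLaMA-style naming, not something this file prescribes:

```python
from auto_gptq.nn_modules.fused_modules.linear import FusedGeneralQuantLinear

# `layer` is a hypothetical decoder layer whose projections are already QuantLinear modules
qkv_proj = FusedGeneralQuantLinear.fuse(
    layer.self_attn.q_proj,
    layer.self_attn.k_proj,
    layer.self_attn.v_proj,
)
# qkv_proj(x) now yields [q | k | v] concatenated on the last dim,
# which FusedAttention later splits with .chunk(chunks=3, dim=-1)
```
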
168 auto_gptq/nn_modules/fused_modules/mlp.py Normal file
@@ -0,0 +1,168 @@
from functools import partial
from typing import Union, Callable, Optional

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion
from torch.nn import functional as F


def act_dropout(
    hidden_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    hidden_states = activation(hidden_states)
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)


def dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    dropout: float = 0.0
):
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)


def act_dropout_res(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    hidden_states = activation(hidden_states)
    hidden_states = hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)
    return torch.add(hidden_states, residual)


class NVFusedActDropoutRes(nn.Module):
    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        is_cuda: bool = False
    ):
        super(NVFusedActDropoutRes, self).__init__()

        fn = partial(F.dropout, p=dropout)
        if activation is not None and residual:
            fn = partial(act_dropout_res, activation=activation, dropout=dropout)
        elif activation is not None:
            fn = partial(act_dropout, activation=activation, dropout=dropout)
        elif residual:
            fn = partial(dropout_res, dropout=dropout)

        self.fn = fn
        if is_cuda:
            self.fn = memory_efficient_fusion(self.fn)

        self.residual = residual

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if self.residual:
            return self.fn(hidden_states, residual)
        else:
            return self.fn(hidden_states)


class FusedMLP(nn.Module):
    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False,
        residual: bool = False
    ):
        super(FusedMLP, self).__init__()

        if activation is None:
            activation = nn.Identity()

        is_cuda = input_proj.weight.data.device.type == "cuda"

        self.input_proj = input_proj
        self.fused_op1 = NVFusedActDropoutRes(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            residual=False,
            is_cuda=is_cuda
        )
        self.out_proj = out_proj
        self.fused_op2 = NVFusedActDropoutRes(
            activation=None,
            dropout=out_dropout if training else 0.0,
            residual=residual,
            is_cuda=is_cuda
        )

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        return self.fused_op2(self.out_proj(self.fused_op1(self.input_proj(hidden_states))), residual)


def gated_act_dropout(
    gate_states: torch.Tensor,
    up_states: torch.Tensor,
    activation: Union[Callable, nn.Module],
    dropout: float = 0.0
):
    hidden_states = activation(gate_states) * up_states
    return hidden_states if dropout == 0.0 else F.dropout(hidden_states, dropout)


class NVFusedGatedActDropout(nn.Module):
    def __init__(
        self,
        activation: Optional[Union[Callable, nn.Module]] = None,
        dropout: float = 0.0,
        is_cuda: bool = False
    ):
        super(NVFusedGatedActDropout, self).__init__()

        fn = partial(F.dropout, p=dropout)
        if activation is not None:
            fn = partial(gated_act_dropout, activation=activation, dropout=dropout)

        self.fn = fn
        if is_cuda:
            self.fn = memory_efficient_fusion(self.fn)

    def forward(self, gate_states: torch.Tensor, up_states):
        if isinstance(self.fn, nn.Dropout):
            return self.fn(gate_states * up_states)
        return self.fn(gate_states, up_states)


class FusedGatedMLP(nn.Module):
    def __init__(
        self,
        input_proj: nn.Linear,
        out_proj: nn.Linear,
        activation: Optional[Union[Callable, nn.Module]] = None,
        in_dropout: float = 0.0,
        out_dropout: float = 0.0,
        training: bool = False
    ):
        super(FusedGatedMLP, self).__init__()

        if activation is None:
            activation = nn.Identity()

        self.input_proj = input_proj
        self.fused_op = NVFusedGatedActDropout(
            activation=activation,
            dropout=in_dropout if training else 0.0,
            is_cuda=input_proj.weight.data.device.type == "cuda"
        )
        self.out_proj = out_proj
        self.out_dropout = nn.Dropout(out_dropout)

        self.intermediate_size = self.input_proj.out_features // 2

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_proj(hidden_states)
        return self.out_dropout(self.out_proj(self.fused_op(*hidden_states.chunk(chunks=2, dim=-1))))
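A hedged sketch of wiring `FusedGatedMLP` for a gated (LLaMA-style) feed-forward block. It assumes the gate and up projections have already been fused into one `input_proj` whose output is `[gate | up]` along the last dimension, which is why `forward` splits it with `chunk(chunks=2, dim=-1)`; the variable names below are placeholders:

```python
import torch.nn as nn
from auto_gptq.nn_modules.fused_modules.mlp import FusedGatedMLP

mlp = FusedGatedMLP(
    input_proj=fused_gate_up_proj,  # assumed: out_features == 2 * intermediate_size
    out_proj=down_proj,
    activation=nn.SiLU(),           # SwiGLU-style gate activation
    training=False,                 # dropout is disabled at inference time
)
hidden_states = mlp(hidden_states)
```
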
@@ -6,7 +6,7 @@ class GeneralQuantLinear(nn.Linear):
         super().__init__(
             in_features=quant_linear_module.infeatures,
             out_features=quant_linear_module.outfeatures,
-            bias=True
+            bias=quant_linear_module.bias is not None
         )
         self.infeatures = quant_linear_module.infeatures
         self.outfeatures = quant_linear_module.outfeatures
@@ -18,28 +18,47 @@ class GeneralQuantLinear(nn.Linear):
 
         self.weight.data = quant_linear_module.qweight
         self.register_buffer('qweight', quant_linear_module.qweight)
-        self.bias.data = quant_linear_module.bias
-        self.qweight.requires_grad = False
-        self.bias.requires_grad = False
+        if quant_linear_module.bias is not None:
+            self.bias.data = quant_linear_module.bias
 
         self.register_buffer('qzeros', quant_linear_module.qzeros)
         self.register_buffer('scales', quant_linear_module.scales)
         self.register_buffer('g_idx', quant_linear_module.g_idx)
 
+        # arg of qlinear_cuda and qlinear_cuda_old
         if hasattr(quant_linear_module, "wf"):
             self.wf = quant_linear_module.wf
+        # arg of qlinear_cuda and qlinear_cuda_old
         if hasattr(quant_linear_module, "kernel_switch_threshold"):
             self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
+        # arg of qlinear_cuda and qlinear_cuda_old
        if hasattr(quant_linear_module, "autogptq_cuda_available"):
             self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
+        # arg of qlinear_cuda and qlinear_cuda_old
+        if hasattr(quant_linear_module, "autogptq_cuda"):
+            self.autogptq_cuda = quant_linear_module.autogptq_cuda
+        # arg of qlinear_cuda_old
+        if hasattr(quant_linear_module, "half_indim"):
+            self.half_indim = quant_linear_module.half_indim
+        # arg of qlinear_cuda_old
+        if hasattr(quant_linear_module, "use_cuda_fp16"):
+            self.use_cuda_fp16 = quant_linear_module.use_cuda_fp16
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "_use_act_order"):
+            self._use_act_order = quant_linear_module._use_act_order
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "width"):
+            self.width = quant_linear_module.width
+        # arg of qlinear_exllama
+        if hasattr(quant_linear_module, "q4"):
+            self.q4 = quant_linear_module.q4
 
         self.trainable = quant_linear_module.trainable
 
         self.forward = quant_linear_module.forward
 
     @classmethod
-    def inject_to_model(cls, model, target_module_type):
+    def convert_to_torch_linear(cls, model: nn.Module, target_module_type: "QuantLinear"):
         for name, m in model.named_modules():
             if not isinstance(m, target_module_type):
                 continue
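The renamed `convert_to_torch_linear` classmethod walks `model.named_modules()` and swaps every matching quant-linear module in place; its body is truncated in this hunk. A hedged sketch of the generic parent-lookup-and-setattr pattern such a conversion typically relies on (the helper name is hypothetical):

```python
import torch.nn as nn

def replace_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # resolve the parent module from the dotted name, then rebind the child attribute
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
```
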
@@ -36,8 +36,6 @@ class QuantLinear(nn.Module):
         global _autogptq_cuda_available
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-        if trainable:
-            _autogptq_cuda_available = False
 
         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -198,7 +196,7 @@ class QuantLinear(nn.Module):
         x = x.reshape(-1, x.shape[-1])
         if self.autogptq_cuda_available and (
             self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
-        ):
+        ) and not self.trainable:
             out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
             if self.bits == 2:
                 self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
@@ -36,8 +36,7 @@ class QuantLinear(nn.Module):
         global _autogptq_cuda_available
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-        if trainable:
-            _autogptq_cuda_available = False
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
@@ -198,7 +197,7 @@ class QuantLinear(nn.Module):
         x = x.reshape(-1, x.shape[-1])
         if self.autogptq_cuda_available is True and (
             self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
-        ):
+        ) and not self.trainable:
             out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
             if self.use_cuda_fp16:
                 x = x.half()
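Both CUDA `QuantLinear` variants now gate the fused kernel on `not self.trainable` instead of disabling `_autogptq_cuda_available` globally at construction time. A hedged restatement of the dispatch rule as a standalone predicate (simplified; the old-CUDA variant compares `kernel_switch_threshold is False` rather than `== 0`):

```python
def use_cuda_kernel(autogptq_cuda_available: bool,
                    kernel_switch_threshold: int,
                    num_rows: int,
                    trainable: bool) -> bool:
    # fall back to the pure-PyTorch dequantize+matmul path when training,
    # so autograd keeps working for PEFT fine-tuning
    within_threshold = kernel_switch_threshold == 0 or num_rows < kernel_switch_threshold
    return autogptq_cuda_available and within_threshold and not trainable
```
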
@@ -354,7 +354,7 @@ def get_gptq_peft_model(
     train_mode: bool = False
 ):
     if train_mode and not model.trainable:
-        model.enable_trainable_mode()
+        raise TypeError("model is not trainable, please load model with 'trainable=True'")
     if train_mode and not peft_config:
         raise ValueError("peft_config not specified when in train mode.")
     if not train_mode and not model_id:
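With this change the model must already be loaded in trainable mode; `get_gptq_peft_model` no longer flips the flag itself. A hedged sketch of the expected call order (the model path and `peft_config` object are placeholders):

```python
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.peft_utils import get_gptq_peft_model

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",
    trainable=True,                 # must be requested up front now
    inject_fused_attention=False,   # keep the plain quant-linear path for training
    disable_exllama=True,
)
peft_model = get_gptq_peft_model(model, peft_config=peft_config, train_mode=True)
```
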
@@ -87,30 +87,29 @@ def load_data(data_path, tokenizer, n_samples, max_new_tokens):
         outputs = examples["output"]
 
         prompts = []
-        texts = []
+        outs = []
         input_ids = []
         attention_mask = []
         for istr, inp, opt in zip(instructions, inputs, outputs):
             if inp:
                 prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
-                text = prompt + opt
             else:
                 prompt = f"Instruction:\n{istr}\nOutput:\n"
-                text = prompt + opt
             if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length - max_new_tokens:
                 continue
 
-            tokenized_data = tokenizer(text)
+            tokenized_data = tokenizer(prompt)
 
             input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
             attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
             prompts.append(prompt)
-            texts.append(text)
+            outs.append(opt)
 
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "prompt": prompts
+            "prompt": prompts,
+            "output": outs
         }
 
     dataset = Dataset.from_generator(dummy_gen)
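After this change only the prompt is tokenized, and the reference completion is carried along as plain text under a new `output` key. A hedged sketch of the shape of one yielded example (token ids are made up):

```python
example = {
    "input_ids": [1, 15043, 3186],   # tokenizer(prompt)["input_ids"], truncated to model_max_length
    "attention_mask": [1, 1, 1],
    "prompt": "Instruction:\n...\nOutput:\n",
    "output": "the reference completion text",
}
```
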
@@ -236,9 +235,9 @@ def main():
     parser.add_argument("--use_triton", action="store_true")
     parser.add_argument("--use_safetensors", action="store_true")
     parser.add_argument("--use_fast_tokenizer", action="store_true")
+    parser.add_argument("--inject_fused_attention", action="store_true")
+    parser.add_argument("--inject_fused_mlp", action="store_true")
     parser.add_argument("--disable_exllama", action="store_true")
-    parser.add_argument("--no_inject_fused_attention", action="store_true")
-    parser.add_argument("--no_inject_fused_mlp", action="store_true")
     parser.add_argument("--num_samples", type=int, default=10)
     parser.add_argument("--per_gpu_max_memory", type=int, default=None)
     parser.add_argument("--cpu_max_memory", type=int, default=None)
@@ -277,8 +276,8 @@ def main():
         use_triton=args.use_triton,
         use_safetensors=args.use_safetensors,
         use_fast_tokenizer=args.use_fast_tokenizer,
-        inject_fused_attention=not args.no_inject_fused_attention,
-        inject_fused_mlp=not args.no_inject_fused_mlp,
+        inject_fused_attention=args.inject_fused_attention,
+        inject_fused_mlp=args.inject_fused_mlp,
         disable_exllama=args.disable_exllama
     )
     end = time.time()
@@ -289,7 +288,7 @@ def main():
 
     if args.use_triton:
         logger.info("warmup triton, this may take a while.")
-        model.warmup_triton()
+        model.warmup_triton(model)
 
     logger.info("loading data")
     examples = load_data(
@@ -37,12 +37,18 @@ if __name__ == "__main__":
     parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
     parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
+    parser.add_argument("--inject_fused_attention", action="store_true", help="Whether to inject fused attention")
+    parser.add_argument("--inject_fused_mlp", action="store_true", help="Whether to inject fused mlp")
     parser.add_argument("--disable_exllama", action="store_true", help="Whether to disable the exllama kernel")
     args = parser.parse_args()
 
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_name,
+        use_fast=args.use_fast_tokenizer,
+        trust_remote_code=args.trust_remote_code
+    )
     if not tokenizer.pad_token_id:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
@@ -68,8 +74,8 @@ if __name__ == "__main__":
             model_basename=args.model_basename,
             use_safetensors=args.use_safetensors,
             trust_remote_code=args.trust_remote_code,
-            inject_fused_mlp=False,
-            inject_fused_attention=False,
+            inject_fused_mlp=args.inject_fused_mlp,
+            inject_fused_attention=args.inject_fused_attention,
             disable_exllama=args.disable_exllama
         )
     else:
5 setup.py
@@ -68,7 +68,10 @@ requirements = [
     "datasets",
     "numpy",
     "rouge",
-    "torch>=1.13.0",
+    "torch>=2.0.1",
+    "functorch",
+    "xformers>=0.0.20",
+    "vllm",
     "safetensors",
     "transformers>=4.31.0",
     "peft"