From 229b61e20ebbfb1b58a4a053bcd3811da67bcd9c Mon Sep 17 00:00:00 2001
From: PanQiWei <594557445@qq.com>
Date: Fri, 14 Apr 2023 01:09:40 +0800
Subject: [PATCH] first init

---
 auto_gptq/__init__.py                       |   1 +
 auto_gptq/modeling/__init__.py              |  14 +
 auto_gptq/modeling/_base.py                 | 328 +++++++++++++
 auto_gptq/modeling/_const.py                |  13 +
 auto_gptq/modeling/_utils.py                |  39 ++
 auto_gptq/modeling/bloom.py                 |  15 +
 auto_gptq/modeling/gpt_neox.py              |  15 +
 auto_gptq/modeling/gptj.py                  |  15 +
 auto_gptq/modeling/llama.py                 |  15 +
 auto_gptq/modeling/opt.py                   |  23 +
 auto_gptq/modeling_auto.py                  |  54 +++
 auto_gptq/quantization/ACKNOWLEDGEMENT.md   |   1 +
 auto_gptq/quantization/__init__.py          |   2 +
 auto_gptq/quantization/gptq.py              | 180 +++++++
 auto_gptq/quantization/quant.py             | 349 ++++++++++++++
 auto_gptq/quantization/quant_cuda.cpp       |  70 +++
 auto_gptq/quantization/quant_cuda_kernel.cu | 509 ++++++++++++++++++++
 auto_gptq/quantization/setup_cuda.py        |  10 +
 18 files changed, 1653 insertions(+)
 create mode 100644 auto_gptq/__init__.py
 create mode 100644 auto_gptq/modeling/__init__.py
 create mode 100644 auto_gptq/modeling/_base.py
 create mode 100644 auto_gptq/modeling/_const.py
 create mode 100644 auto_gptq/modeling/_utils.py
 create mode 100644 auto_gptq/modeling/bloom.py
 create mode 100644 auto_gptq/modeling/gpt_neox.py
 create mode 100644 auto_gptq/modeling/gptj.py
 create mode 100644 auto_gptq/modeling/llama.py
 create mode 100644 auto_gptq/modeling/opt.py
 create mode 100644 auto_gptq/modeling_auto.py
 create mode 100644 auto_gptq/quantization/ACKNOWLEDGEMENT.md
 create mode 100644 auto_gptq/quantization/__init__.py
 create mode 100644 auto_gptq/quantization/gptq.py
 create mode 100644 auto_gptq/quantization/quant.py
 create mode 100644 auto_gptq/quantization/quant_cuda.cpp
 create mode 100644 auto_gptq/quantization/quant_cuda_kernel.cu
 create mode 100644 auto_gptq/quantization/setup_cuda.py

diff --git a/auto_gptq/__init__.py b/auto_gptq/__init__.py
new file mode 100644
index 0000000..4339aa0
--- /dev/null
+++ b/auto_gptq/__init__.py
@@ -0,0 +1 @@
+from .modeling_auto import AutoGPTQModelForCausalLM
diff --git a/auto_gptq/modeling/__init__.py b/auto_gptq/modeling/__init__.py
new file mode 100644
index 0000000..7fab50d
--- /dev/null
+++ b/auto_gptq/modeling/__init__.py
@@ -0,0 +1,14 @@
+from ._base import BaseQuantizeConfig
+from .bloom import *
+from .gpt_neox import *
+from .gptj import *
+from .llama import *
+from .opt import *
+
+GPTQ_CAUSAL_LM_MODEL_MAP = {
+    "bloom": BloomGPTQForCausalLM,
+    "gpt_neox": GPTNeoXGPTQForCausalLM,
+    "gptj": GPTJGPTQForCausalLM,
+    "llama": LlamaGPTQForCausalLM,
+    "opt": OPTGPTQForCausalLM
+}
diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
new file mode 100644
index 0000000..dc65f81
--- /dev/null
+++ b/auto_gptq/modeling/_base.py
@@ -0,0 +1,328 @@
+import json
+from dataclasses import dataclass, field, fields
+from logging import getLogger
+from os.path import join
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+
+from ._const import *
+from ._utils import *
+from ..quantization import *
+
+logger = getLogger(__name__)
+
+
+@dataclass
+class BaseQuantizeConfig:
+    bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
+    damp_percent: float = field(default=0.01)
+    desc_act: bool = field(default=True)
+    group_size: int = field(default=-1)
+
+    def __post_init__(self):
+        fields_info = fields(self)
+
+        if self.bits not in fields_info[0].metadata["choices"]:
+            raise 
ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("unless equal to -1, group_size must greater then 0.") + + def save_pretrained(self, save_dir: str): + with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2) + + @classmethod + def from_pretrained(cls, save_dir: str): + with open(join(save_dir, "quantize_config.json"), "r", encoding="utf-8") as f: + return cls(**json.load(f)) + + def to_dict(self): + return { + "bits": self.bits, + "damp_percent": self.damp_percent, + "desc_act": self.desc_act, + "group_size": self.group_size + } + + +class BaseGPTQForCausalLM: + layers_block_name: str = None + outside_layer_modules: List[str] = None + inside_layer_modules: List[List[str]] = None + + def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig): + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.quantize_config = quantize_config + + @property + def quantized(self): + return self._quantized + + def _move_outside_layer_modules(self, device): + for module_name in self.outside_layer_modules: + module = get_module_by_name(self.model, module_name) + if module is not None: + module.to(device) + + @staticmethod + def _resize_attention_mask(attention_mask: List[torch.LongTensor]): + return attention_mask + + def quantize(self, examples: List[Dict[str, torch.LongTensor]]): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + layer_inputs = [] + attention_masks = [] + layer_outputs = [] + + class LayerHijacker(nn.Module): + """hijack layer's forward pass to cache data""" + + def __init__(self, m): + super().__init__() + self.module = m + + def forward(self, inp, **kwargs): + bsz = inp.size(0) + for i in range(bsz): + layer_inputs.append(inp[i].to(CPU)) + attention_masks.append(kwargs["attention_mask"][i].to(CPU)) + raise ValueError + + forward_pass_use_cache = self.model.config.use_cache + self.model.config.use_cache = False + + num_examples = len(examples) + layers = get_module_by_name(self.model, self.layers_block_name) + + layers[0] = layers[0].to(CUDA) + self._move_outside_layer_modules(CUDA) + + # get inputs for first layer + layers[0] = LayerHijacker(layers[0]) + for example in examples: + for k, v in example.items(): + if k == "input_ids" and len(v.shape) == 1: + v = v.unsqueeze(0) + example[k] = v.to(CUDA) + try: + self.model(**example) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + self._move_outside_layer_modules(CPU) + + torch.cuda.empty_cache() + + # resize attention mask for some special models + attention_masks = self._resize_attention_mask(attention_masks) + + quantizers = {} + for i in range(len(layers)): + logger.info(f"Start quantizing layer {i + 1}/{len(layers)}") + layer = layers[i].to(CUDA) + + full = find_layers(layer) + for names in self.inside_layer_modules: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( + self.quantize_config.bits, + perchannel=True, + sym=True, + mse=False + ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = 
[]
+                for name in subset:
+                    handles.append(subset[name].register_forward_hook(add_batch(name)))
+                for j in range(num_examples):
+                    layer_input = layer_inputs[j].unsqueeze(0).to("cuda:0")
+                    layer_attention_mask = attention_masks[j].to("cuda:0")
+                    layer(layer_input, attention_mask=layer_attention_mask)
+                for h in handles:
+                    h.remove()
+
+                for name in subset:
+                    logger.info(f'Quantizing {name} in layer {i + 1}/{len(layers)}...')
+                    scale, zero, g_idx = gptq[name].fasterquant(
+                        percdamp=self.quantize_config.damp_percent,
+                        groupsize=self.quantize_config.group_size,
+                        actorder=self.quantize_config.desc_act
+                    )
+                    quantizers[f'{self.layers_block_name}.{i}.{name}'] = (
+                        gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()
+                    )
+                    gptq[name].free()
+
+            for j in range(num_examples):
+                layer_input = layer_inputs[j].unsqueeze(0).to(CUDA)
+                layer_attention_mask = attention_masks[j].to(CUDA)
+                layer_output = layer(layer_input, attention_mask=layer_attention_mask)[0][0].cpu()
+                layer_outputs.append(layer_output)
+
+            layers[i] = layer.to(CPU)
+            del layer
+            del gptq
+            torch.cuda.empty_cache()
+
+            layer_inputs, layer_outputs = layer_outputs, []
+
+        pack_model(
+            model=self.model,
+            quantizers=quantizers,
+            bits=self.quantize_config.bits,
+            group_size=self.quantize_config.group_size
+        )
+        self._quantized = True
+        self.model.config.use_cache = forward_pass_use_cache
+
+    def generate(self, inputs, **kwargs):
+        """shortcut for model.generate"""
+        return self.model.generate(inputs, **kwargs)
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        """shortcut for model.prepare_inputs_for_generation"""
+        return self.model.prepare_inputs_for_generation(*args, **kwargs)
+
+    def save_quantized(self, save_dir: str, use_safetensors: bool = False):
+        """save quantized model and configs to local disk"""
+        if use_safetensors:
+            try:
+                import safetensors
+            except ImportError:
+                logger.warning("safetensors is not installed, will save to .bin file.")
+                use_safetensors = False
+            else:
+                from safetensors.torch import save_file as safe_save
+
+        if not self.quantized:
+            raise EnvironmentError("can only save quantized model, please execute .quantize first.")
+
+        self.model.to(CPU)
+
+        model_save_name = f"gptq_model-{self.quantize_config.bits}bit"
+        if use_safetensors:
+            model_save_name += ".safetensors"
+            state_dict = self.model.state_dict()
+            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
+            safe_save(state_dict, join(save_dir, model_save_name))
+        else:
+            model_save_name += ".bin"
+            torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+        self.model.config.save_pretrained(save_dir)
+        self.quantize_config.save_pretrained(save_dir)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        quantize_config: BaseQuantizeConfig,
+        bf16: bool = False,
+        **model_init_kwargs
+    ):
+        """load un-quantized pretrained model to cpu"""
+
+        def skip(*args, **kwargs):
+            pass
+
+        torch.nn.init.kaiming_uniform_ = skip
+        torch.nn.init.uniform_ = skip
+        torch.nn.init.normal_ = skip
+
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        if config.model_type not in SUPPORTED_MODELS:
+            raise TypeError(f"{config.model_type} isn't supported yet.")
+
+        # enforce some values regardless of what the user specified
+        model_init_kwargs["device_map"] = None
+        model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
+        model_init_kwargs["low_cpu_mem_usage"] = False
+
+        model = 
AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
+        model.seqlen = model.config.max_position_embeddings
+        model.eval()
+
+        return cls(model, False, quantize_config)
+
+    @classmethod
+    def from_quantized(
+        cls,
+        save_dir: str,
+        device: str = "cpu",
+        use_safetensors: bool = False
+    ):
+        """load quantized model from local disk"""
+        if use_safetensors:
+            try:
+                import safetensors
+            except ImportError:
+                logger.warning("safetensors is not installed, will load .bin file.")
+                use_safetensors = False
+            else:
+                from safetensors.torch import load_file as safe_load
+
+        config = AutoConfig.from_pretrained(save_dir)
+        if config.model_type not in SUPPORTED_MODELS:
+            raise TypeError(f"{config.model_type} isn't supported yet.")
+
+        quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)
+
+        model_save_name = f"gptq_model-{quantize_config.bits}bit"
+        if use_safetensors:
+            model_save_name += ".safetensors"
+        else:
+            model_save_name += ".bin"
+
+        def skip(*args, **kwargs):
+            pass
+
+        torch.nn.init.kaiming_uniform_ = skip
+        torch.nn.init.uniform_ = skip
+        torch.nn.init.normal_ = skip
+
+        transformers.modeling_utils._init_weights = False
+        torch.set_default_dtype(torch.half)
+        model = AutoModelForCausalLM.from_config(config, **{"low_cpu_mem_usage": False, "device_map": None})
+        torch.set_default_dtype(torch.float)
+        model = model.eval()
+        layers = find_layers(model)
+        for name in ['lm_head']:
+            if name in layers:
+                del layers[name]
+        make_quant(model, layers, quantize_config.bits, quantize_config.group_size)
+
+        if model_save_name.endswith('.safetensors'):
+            model.load_state_dict(safe_load(join(save_dir, model_save_name), "cpu"))
+        else:
+            model.load_state_dict(torch.load(join(save_dir, model_save_name)))
+        model.seqlen = model.config.max_position_embeddings
+
+        model.eval()
+        model.to(device)
+
+        return model
+
+
+__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]
diff --git a/auto_gptq/modeling/_const.py b/auto_gptq/modeling/_const.py
new file mode 100644
index 0000000..6a42760
--- /dev/null
+++ b/auto_gptq/modeling/_const.py
@@ -0,0 +1,13 @@
+from packaging.version import parse as parse_version
+
+from torch import device
+from transformers import __version__ as transformers_version
+
+CPU = device("cpu")
+CUDA = device("cuda:0")
+
+SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt"]
+if parse_version(transformers_version) >= parse_version("v4.28.0"):
+    SUPPORTED_MODELS.append("llama")
+
+__all__ = ["CPU", "CUDA", "SUPPORTED_MODELS"]
diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py
new file mode 100644
index 0000000..691471e
--- /dev/null
+++ b/auto_gptq/modeling/_utils.py
@@ -0,0 +1,39 @@
+from logging import getLogger
+
+import torch.nn as nn
+
+from ..quantization import make_quant, QuantLinear
+
+logger = getLogger(__name__)
+
+
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(find_layers(child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1)) + return res + + +def get_module_by_name(model, module_name: str): + for name, module in model.named_modules(): + if name.startswith(module_name): + return module + + +def pack_model(model, quantizers, bits, group_size): + model.cpu() + logger.info('Packing model...') + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant(model, quantizers, bits, group_size) + qlayers = find_layers(model, [QuantLinear]) + for name in qlayers: + logger.info(name) + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + logger.info('Model packed.') + + +__all__ = ["find_layers", "get_module_by_name", "pack_model"] diff --git a/auto_gptq/modeling/bloom.py b/auto_gptq/modeling/bloom.py new file mode 100644 index 0000000..8ce4325 --- /dev/null +++ b/auto_gptq/modeling/bloom.py @@ -0,0 +1,15 @@ +from ._base import * + + +class BloomGPTQForCausalLM(BaseGPTQForCausalLM): + layers_block_name = "transformer.h" + outside_layer_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"] + inside_layer_modules = [ + ["self_attention.query_key_value"], + ["self_attention.dense"], + ["mlp.dense_h_to_4h"], + ["mlp.dense_4h_to_h"] + ] + + +__all__ = ["BloomGPTQForCausalLM"] diff --git a/auto_gptq/modeling/gpt_neox.py b/auto_gptq/modeling/gpt_neox.py new file mode 100644 index 0000000..795760b --- /dev/null +++ b/auto_gptq/modeling/gpt_neox.py @@ -0,0 +1,15 @@ +from ._base import * + + +class GPTNeoXGPTQForCausalLM(BaseGPTQForCausalLM): + layers_block_name = "gpt_neox.layers" + outside_layer_modules = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"] + inside_layer_modules = [ + ["attention.query_key_value"], + ["attention.dense"], + ["mlp.dense_h_to_4h"], + ["mlp.dense_4h_to_h"] + ] + + +__all__ = ["GPTNeoXGPTQForCausalLM"] diff --git a/auto_gptq/modeling/gptj.py b/auto_gptq/modeling/gptj.py new file mode 100644 index 0000000..94a3cdc --- /dev/null +++ b/auto_gptq/modeling/gptj.py @@ -0,0 +1,15 @@ +from ._base import * + + +class GPTJGPTQForCausalLM(BaseGPTQForCausalLM): + layers_block_name = "transformer.h" + outside_layer_modules = ["transformer.wte", "transformer.ln_f"] + inside_layer_modules = [ + ["attn.k_proj", "attn.v_proj", "attn.q_proj"], + ["attn.out_proj"], + ["mlp.fc_in"], + ["mlp.fc_out"] + ] + + +__all__ = ["GPTJGPTQForCausalLM"] diff --git a/auto_gptq/modeling/llama.py b/auto_gptq/modeling/llama.py new file mode 100644 index 0000000..06bbadf --- /dev/null +++ b/auto_gptq/modeling/llama.py @@ -0,0 +1,15 @@ +from ._base import * + + +class LlamaGPTQForCausalLM(BaseGPTQForCausalLM): + layers_block_name = "model.layers" + outside_layer_modules = ["model.embed_tokens", "model.norm"] + inside_layer_modules = [ + ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], + ["self_attn.o_proj"], + ["mlp.up_proj", "mlp.gate_proj"], + ["mlp.down_proj"] + ] + + +__all__ = ["LlamaGPTQForCausalLM"] diff --git a/auto_gptq/modeling/opt.py b/auto_gptq/modeling/opt.py new file mode 100644 index 0000000..dfbb1f3 --- /dev/null +++ b/auto_gptq/modeling/opt.py @@ -0,0 +1,23 @@ +from ._base import * + + +class OPTGPTQForCausalLM(BaseGPTQForCausalLM): + layers_block_name = "model.decoder.layers" + outside_layer_modules = [ + "model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out", + "model.decoder.project_in", "model.decoder.final_layer_norm" + ] + inside_layer_modules = [ + ["self_attn.k_proj", "self_attn.v_proj", 
"self_attn.q_proj"], + ["self_attn.out_proj"], + ["fc1"], + ["fc2"] + ] + + @staticmethod + def _resize_attention_mask(attention_mask): + attention_mask = [attention_mask.unsqueeze(1) for attention_mask in attention_mask] + return attention_mask + + +__all__ = ["OPTGPTQForCausalLM"] diff --git a/auto_gptq/modeling_auto.py b/auto_gptq/modeling_auto.py new file mode 100644 index 0000000..25dfecb --- /dev/null +++ b/auto_gptq/modeling_auto.py @@ -0,0 +1,54 @@ +from transformers import AutoConfig + +from .modeling import BaseQuantizeConfig, GPTQ_CAUSAL_LM_MODEL_MAP +from .modeling._const import SUPPORTED_MODELS + + +def check_and_get_model_type(model_dir): + config = AutoConfig.from_pretrained(model_dir) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + model_type = config.model_type + return model_type + + +class AutoGPTQModelForCausalLM: + def __init__(self): + raise EnvironmentError( + "AutoGPTQModelForCausalLM is designed to be instantiated\n" + "using `AutoGPTQModelForCausalLM.from_pretrained` if want to quantize a pretrained model.\n" + "using `AutoGPTQModelForCausalLM.from_quantized` if want to inference with quantized model." + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + quantize_config: BaseQuantizeConfig, + bf16: bool = False, + **model_init_kwargs + ): + model_type = check_and_get_model_type(pretrained_model_name_or_path) + return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + quantize_config=quantize_config, + bf16=bf16, + **model_init_kwargs + ) + + @classmethod + def from_quantized( + cls, + save_dir: str, + device: str = "cpu", + use_safetensors: bool = False + ): + model_type = check_and_get_model_type(save_dir) + return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized( + save_dir=save_dir, + device=device, + use_safetensors=use_safetensors + ) + + +__all__ = ["AutoGPTQModelForCausalLM"] diff --git a/auto_gptq/quantization/ACKNOWLEDGEMENT.md b/auto_gptq/quantization/ACKNOWLEDGEMENT.md new file mode 100644 index 0000000..7c8dedc --- /dev/null +++ b/auto_gptq/quantization/ACKNOWLEDGEMENT.md @@ -0,0 +1 @@ +The codes in this directory are mainly referenced from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq) \ No newline at end of file diff --git a/auto_gptq/quantization/__init__.py b/auto_gptq/quantization/__init__.py new file mode 100644 index 0000000..5b1362e --- /dev/null +++ b/auto_gptq/quantization/__init__.py @@ -0,0 +1,2 @@ +from .gptq import * +from .quant import * diff --git a/auto_gptq/quantization/gptq.py b/auto_gptq/quantization/gptq.py new file mode 100644 index 0000000..4cae69a --- /dev/null +++ b/auto_gptq/quantization/gptq.py @@ -0,0 +1,180 @@ +import math +import os +import time +from logging import getLogger + +import torch +import torch.nn as nn +import transformers + +from .quant import * + +logger = getLogger(__name__) + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class GPTQ: + def __init__(self, layer): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), 
device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp, out): + if os.environ.get("DEBUG"): + self.inp1 = inp + self.out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def fasterquant( + self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False + ): + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if os.environ.get("DEBUG"): + self.layer.weight.data[:, :i2] = Q[:, :i2] + self.layer.weight.data[:, i2:] = W[:, i2:] + logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + logger.debug(torch.sum(Losses)) + + torch.cuda.synchronize() + logger.info(f'duration: {(time.time() - tick)}') + logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}') + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + self.layer.weight.data = 
Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + if os.environ.get("DEBUG"): + logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx + + def free(self): + if os.environ.get("DEBUG"): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() + +__all__ = ["GPTQ"] diff --git a/auto_gptq/quantization/quant.py b/auto_gptq/quantization/quant.py new file mode 100644 index 0000000..950a0f4 --- /dev/null +++ b/auto_gptq/quantization/quant.py @@ -0,0 +1,349 @@ +import math +from logging import getLogger + +import numpy as np +import torch +import torch.nn as nn + + +logger = getLogger(__name__) + + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure( + self, + bits, perchannel=False, sym=True, + mse=False, norm=2.4, grid=100, maxshrink=.8, + trits=False + ): + + self.maxq = torch.tensor(2 ** bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + 
self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +try: + import quant_cuda + is_cuda = True +except: + logger.warning('CUDA extension not installed.') + is_cuda = False + + +def make_quant(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + + +class QuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): + super().__init__() + if bits not in [2, 3, 4, 8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), + dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2, 4, 8]: + self.register_buffer('wf', torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0), + persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], ], + dtype=torch.int32).reshape(1, 3, 12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to( + torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), 
dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1; + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.reshape(-1, x.shape[-1]) + if self.is_cuda is True and ( + self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2, 4, 8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0)).to( + torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1)).to( + torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(-1, -1, -1, + 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:, :, 0, 10] = (zeros[:, :, 0, 
10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) + zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(-1, -1, + 12, -1) + weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 + weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) + weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + + +__all__ = [ + "quantize", + "make_quant", + "Quantizer", + "QuantLinear" +] diff --git a/auto_gptq/quantization/quant_cuda.cpp b/auto_gptq/quantization/quant_cuda.cpp new file mode 100644 index 0000000..1dbfcde --- /dev/null +++ b/auto_gptq/quantization/quant_cuda.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +void vecquant2matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant3matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant8matmul", 
&vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)"); +} diff --git a/auto_gptq/quantization/quant_cuda_kernel.cu b/auto_gptq/quantization/quant_cuda_kernel.cu new file mode 100644 index 0000000..27addff --- /dev/null +++ b/auto_gptq/quantization/quant_cuda_kernel.cu @@ -0,0 +1,509 @@ +#include +#include +#include +#include +#include + +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +__device__ double atomicAdd( + double* address, + double val +) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, + assumed, + __double_as_longlong(val + __longlong_as_double(assumed)) + ); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +const int BLOCKWIDTH = 256; +const int BLOCKHEIGHT2 = 16; +const int BLOCKHEIGHT3 = 24; +const int BLOCKHEIGHT4 = 32; +const int BLOCKHEIGHT8 = 64; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + 
const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 16; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); + int k_bit = (k % 16) * 2; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant3matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant3matmul_cuda", ([&] { + VecQuant3MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = (h / 3) * 32; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + unsigned int z_tmp; + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; + int k_mod = k % 32; + int k_bit; + + if (k_mod != 10){ + if (k_mod != 21){ + k_bit = k_mod; + if (k_bit > 21){ + k_bit -= 22; + k_bit *= 3; + k_bit += 2; + k_w += 2; + } else if (k_bit > 10){ + k_bit -= 11; + k_bit *= 3; + k_bit += 1; + k_w += 1; + } else { + k_bit *= 3; + } + } else { + k_w += 1; + } + } + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = 
(as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scalar_t((z_tmp) + 1); + } else { + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + if (k_mod == 10) { + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); + } else if (k_mod == 21){ + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); + } else { + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); + } + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), 
zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } +} diff --git a/auto_gptq/quantization/setup_cuda.py b/auto_gptq/quantization/setup_cuda.py new file mode 100644 index 0000000..6f05634 --- /dev/null +++ b/auto_gptq/quantization/setup_cuda.py @@ -0,0 +1,10 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup( + name='quant_cuda', + ext_modules=[cpp_extension.CUDAExtension( + 'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu'] + )], + cmdclass={'build_ext': cpp_extension.BuildExtension} +)
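
A minimal end-to-end usage sketch of the API introduced above, assuming a CUDA device is available and that the quant_cuda extension has already been built (for example with "python setup_cuda.py install" run inside auto_gptq/quantization). The model id, output directory, and calibration sentence below are illustrative placeholders, not values prescribed by the patch.

import os

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQModelForCausalLM
from auto_gptq.modeling import BaseQuantizeConfig

pretrained_model_name = "facebook/opt-125m"  # placeholder: any model whose type is in SUPPORTED_MODELS
quantized_model_dir = "opt-125m-4bit-128g"   # placeholder: local output directory

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
# quantize() expects a list of dicts holding "input_ids" and "attention_mask" tensors
calibration_text = "auto_gptq is an easy-to-use GPTQ quantization package."  # placeholder
enc = tokenizer(calibration_text, return_tensors="pt")
examples = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)

# load the fp16 model on CPU, run GPTQ layer by layer on GPU, then pack and save the weights
model = AutoGPTQModelForCausalLM.from_pretrained(pretrained_model_name, quantize_config)
model.quantize(examples)

os.makedirs(quantized_model_dir, exist_ok=True)
model.save_quantized(quantized_model_dir)

# reload the packed model; in this initial version from_quantized returns the underlying
# transformers model, so inference goes through the standard generate() API
model = AutoGPTQModelForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
input_ids = tokenizer("auto_gptq is", return_tensors="pt").input_ids.to("cuda:0")
print(tokenizer.decode(model.generate(input_ids=input_ids, max_new_tokens=32)[0]))

With the default use_safetensors=False, save_quantized writes gptq_model-4bit.bin next to config.json and quantize_config.json, which is what from_quantized looks for when reloading.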