first init

PanQiWei 2023-04-14 01:09:40 +08:00
parent fe0ed2cd5d
commit 229b61e20e
18 changed files with 1653 additions and 0 deletions

auto_gptq/__init__.py

@@ -0,0 +1 @@
from .modeling_auto import AutoGPTQModelForCausalLM


@@ -0,0 +1,14 @@
from ._base import BaseQuantizeConfig
from .bloom import *
from .gpt_neox import *
from .gptj import *
from .llama import *
from .opt import *
GPTQ_CAUSAL_LM_MODEL_MAP = {
"bloom": BloomGPTQForCausalLM,
"gpt_neox": GPTNeoXGPTQForCausalLM,
"gptj": GPTJGPTQForCausalLM,
"llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM
}

auto_gptq/modeling/_base.py

@@ -0,0 +1,328 @@
import json
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from ._const import *
from ._utils import *
from ..quantization import *
logger = getLogger(__name__)
@dataclass
class BaseQuantizeConfig:
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
group_size: int = field(default=-1)
def __post_init__(self):
fields_info = fields(self)
if self.bits not in fields_info[0].metadata["choices"]:
raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.")
if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must between 0 and 1.")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("unless equal to -1, group_size must greater then 0.")
def save_pretrained(self, save_dir: str):
with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)
@classmethod
def from_pretrained(cls, save_dir: str):
with open(join(save_dir, "quantize_config.json"), "r", encoding="utf-8") as f:
return cls(**json.load(f))
def to_dict(self):
return {
"bits": self.bits,
"damp_percent": self.damp_percent,
"desc_act": self.desc_act,
"group_size": self.group_size
}
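# For reference (save_dir is illustrative): BaseQuantizeConfig(bits=4, group_size=128, desc_act=False).save_pretrained(save_dir)
# writes a quantize_config.json whose keys mirror to_dict() above, e.g.
# {"bits": 4, "damp_percent": 0.01, "desc_act": false, "group_size": 128}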
class BaseGPTQForCausalLM:
layers_block_name: str = None
outside_layer_modules: List[str] = None
inside_layer_modules: List[List[str]] = None
def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
self.model = model
self.model_type = self.model.config.model_type
self._quantized = quantized
self.quantize_config = quantize_config
@property
def quantized(self):
return self._quantized
def _move_outside_layer_modules(self, device):
for module_name in self.outside_layer_modules:
module = get_module_by_name(self.model, module_name)
if module is not None:
module.to(device)
@staticmethod
def _resize_attention_mask(attention_mask: List[torch.LongTensor]):
return attention_mask
def quantize(self, examples: List[Dict[str, torch.LongTensor]]):
if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")
layer_inputs = []
attention_masks = []
layer_outputs = []
class LayerHijacker(nn.Module):
"""hijack layer's forward pass to cache data"""
def __init__(self, m):
super().__init__()
self.module = m
def forward(self, inp, **kwargs):
bsz = inp.size(0)
for i in range(bsz):
layer_inputs.append(inp[i].to(CPU))
attention_masks.append(kwargs["attention_mask"][i].to(CPU))
raise ValueError
forward_pass_use_cache = self.model.config.use_cache
self.model.config.use_cache = False
num_examples = len(examples)
layers = get_module_by_name(self.model, self.layers_block_name)
layers[0] = layers[0].to(CUDA)
self._move_outside_layer_modules(CUDA)
# get inputs for first layer
layers[0] = LayerHijacker(layers[0])
for example in examples:
for k, v in example.items():
if k == "input_ids" and len(v.shape) == 1:
v = v.unsqueeze(0)
example[k] = v.to(CUDA)
try:
self.model(**example)
except ValueError:
pass
layers[0] = layers[0].module
layers[0] = layers[0].cpu()
self._move_outside_layer_modules(CPU)
torch.cuda.empty_cache()
# resize attention mask for some special models
attention_masks = self._resize_attention_mask(attention_masks)
quantizers = {}
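# quantize decoder layers one by one: collect per-layer Hessian statistics with forward hooks,
# quantize each module group, then recompute this layer's outputs to use as the next layer's inputs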
for i in range(len(layers)):
logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
layer = layers[i].to(CUDA)
full = find_layers(layer)
for names in self.inside_layer_modules:
subset = {n: full[n] for n in names}
gptq = {}
for name in subset:
gptq[name] = GPTQ(subset[name])
gptq[name].quantizer = Quantizer()
gptq[name].quantizer.configure(
self.quantize_config.bits,
perchannel=True,
sym=True,
mse=False
)
def add_batch(name):
def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data)
return tmp
handles = []
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(num_examples):
layer_input = layer_inputs[j].unsqueeze(0).to("cuda:0")
layer_attention_mask = attention_masks[j].to("cuda:0")
layer(layer_input, attention_mask=layer_attention_mask)
for h in handles:
h.remove()
for name in subset:
logger.info(f'Quantizing {name} in layer {i + 1}/{len(layers)}...')
scale, zero, g_idx = gptq[name].fasterquant(
percdamp=self.quantize_config.damp_percent,
groupsize=self.quantize_config.group_size,
actorder=self.quantize_config.desc_act
)
quantizers[f'{self.layers_block_name}.{i}.{name}'] = (
gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()
)
gptq[name].free()
for j in range(num_examples):
layer_input = layer_inputs[j].unsqueeze(0).to(CUDA)
layer_attention_mask = attention_masks[j].to(CUDA)
layer_output = layer(layer_input, attention_mask=layer_attention_mask)[0][0].cpu()
layer_outputs.append(layer_output)
layers[i] = layer.to(CPU)
del layer
del gptq
torch.cuda.empty_cache()
layer_inputs, layer_outputs = layer_outputs, []
pack_model(
model=self.model,
quantizers=quantizers,
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size
)
self._quantized = True
self.model.config.use_cache = forward_pass_use_cache
def generate(self, inputs, **kwargs):
"""shortcut for model.generate"""
return self.model.generate(inputs, **kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
return self.model.prepare_inputs_for_generation(*args, **kwargs)
def save_quantized(self, save_dir: str, use_safetensors: bool = False):
"""save quantized model and configs to local disk"""
if use_safetensors:
try:
import safetensors
except ImportError:
logger.warning("safetensors is not installed, will save to .bin file.")
use_safetensors = False
else:
from safetensors.torch import save_file as safe_save
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
self.model.to(CPU)
model_save_name = f"gptq_model-{self.quantize_config.bits}bit"
if use_safetensors:
model_save_name += ".safetensors"
state_dict = self.model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
safe_save(state_dict, join(save_dir, model_save_name))
else:
model_save_name += ".bin"
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
self.model.config.save_pretrained(save_dir)
self.quantize_config.save_pretrained(save_dir)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
bf16: bool = False,
**model_init_kwargs
):
"""load un-quantized pretrained model to cpu"""
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
# enforce some values regardless of what the user specified
model_init_kwargs["device_map"] = None
model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
model_init_kwargs["low_cpu_mem_usage"] = False
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
model.seqlen = model.config.max_position_embeddings
model.eval()
return cls(model, False, quantize_config)
@classmethod
def from_quantized(
cls,
save_dir: str,
device: str = "cpu",
use_safetensors: bool = False
):
"""load quantized model from local disk"""
if use_safetensors:
try:
import safetensors
except ImportError:
logger.warning("safetensors is not installed, will load .bin file.")
use_safetensors = False
else:
from safetensors.torch import load_file as safe_load
config = AutoConfig.from_pretrained(save_dir)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)
model_save_name = f"gptq_model-{quantize_config.bits}bit"
if use_safetensors:
model_save_name += ".safetensors"
else:
model_save_name += ".bin"
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config, **{"low_cpu_mem_usage": False, "device_map": None})
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant(model, layers, quantize_config.bits, quantize_config.group_size)
if model_save_name.endswith('.safetensors'):
model.load_state_dict(safe_load(join(save_dir, model_save_name), "cpu"))
else:
model.load_state_dict(torch.load(join(save_dir, model_save_name), map_location="cpu"))
model.seqlen = model.config.max_position_embeddings
model.eval()
model.to(device)
return model
__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]


@@ -0,0 +1,13 @@
from packaging.version import parse as parse_version
from torch import device
from transformers import __version__ as transformers_version
CPU = device("cpu")
CUDA = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt"]
if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS.append("llama")
__all__ = ["CPU", "CUDA", "SUPPORTED_MODELS"]


@@ -0,0 +1,39 @@
from logging import getLogger
import torch.nn as nn
from ..quantization import make_quant, QuantLinear
logger = getLogger(__name__)
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res
def get_module_by_name(model, module_name: str):
for name, module in model.named_modules():
if name.startswith(module_name):
return module
def pack_model(model, quantizers, bits, group_size):
model.cpu()
logger.info('Packing model...')
layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(model, quantizers, bits, group_size)
qlayers = find_layers(model, [QuantLinear])
for name in qlayers:
logger.info(name)
quantizers[name], scale, zero, g_idx = quantizers[name]
qlayers[name].pack(layers[name], scale, zero, g_idx)
logger.info('Model packed.')
__all__ = ["find_layers", "get_module_by_name", "pack_model"]


@@ -0,0 +1,15 @@
from ._base import *
class BloomGPTQForCausalLM(BaseGPTQForCausalLM):
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
inside_layer_modules = [
["self_attention.query_key_value"],
["self_attention.dense"],
["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"]
]
__all__ = ["BloomGPTQForCausalLM"]


@@ -0,0 +1,15 @@
from ._base import *
class GPTNeoXGPTQForCausalLM(BaseGPTQForCausalLM):
layers_block_name = "gpt_neox.layers"
outside_layer_modules = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"]
inside_layer_modules = [
["attention.query_key_value"],
["attention.dense"],
["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"]
]
__all__ = ["GPTNeoXGPTQForCausalLM"]


@@ -0,0 +1,15 @@
from ._base import *
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [
["attn.k_proj", "attn.v_proj", "attn.q_proj"],
["attn.out_proj"],
["mlp.fc_in"],
["mlp.fc_out"]
]
__all__ = ["GPTJGPTQForCausalLM"]


@@ -0,0 +1,15 @@
from ._base import *
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"]
]
__all__ = ["LlamaGPTQForCausalLM"]

auto_gptq/modeling/opt.py

@@ -0,0 +1,23 @@
from ._base import *
class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
layers_block_name = "model.decoder.layers"
outside_layer_modules = [
"model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",
"model.decoder.project_in", "model.decoder.final_layer_norm"
]
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.out_proj"],
["fc1"],
["fc2"]
]
@staticmethod
def _resize_attention_mask(attention_mask):
attention_mask = [mask.unsqueeze(1) for mask in attention_mask]
return attention_mask
__all__ = ["OPTGPTQForCausalLM"]


@@ -0,0 +1,54 @@
from transformers import AutoConfig
from .modeling import BaseQuantizeConfig, GPTQ_CAUSAL_LM_MODEL_MAP
from .modeling._const import SUPPORTED_MODELS
def check_and_get_model_type(model_dir):
config = AutoConfig.from_pretrained(model_dir)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
return model_type
class AutoGPTQModelForCausalLM:
def __init__(self):
raise EnvironmentError(
"AutoGPTQModelForCausalLM is designed to be instantiated\n"
"using `AutoGPTQModelForCausalLM.from_pretrained` if want to quantize a pretrained model.\n"
"using `AutoGPTQModelForCausalLM.from_quantized` if want to inference with quantized model."
)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
bf16: bool = False,
**model_init_kwargs
):
model_type = check_and_get_model_type(pretrained_model_name_or_path)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
quantize_config=quantize_config,
bf16=bf16,
**model_init_kwargs
)
@classmethod
def from_quantized(
cls,
save_dir: str,
device: str = "cpu",
use_safetensors: bool = False
):
model_type = check_and_get_model_type(save_dir)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
save_dir=save_dir,
device=device,
use_safetensors=use_safetensors
)
__all__ = ["AutoGPTQModelForCausalLM"]


@@ -0,0 +1 @@
The code in this directory is mainly adapted from @qwopqwop200's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which is itself based on [gptq](https://github.com/IST-DASLab/gptq).


@@ -0,0 +1,2 @@
from .gptq import *
from .quant import *


@@ -0,0 +1,180 @@
import math
import os
import time
from logging import getLogger
import torch
import torch.nn as nn
import transformers
from .quant import *
logger = getLogger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
self.inp1 = inp
self.out1 = out
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def fasterquant(
self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()
tick = time.time()
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if groupsize != -1:
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
q = quantize(
w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if os.environ.get("DEBUG"):
self.layer.weight.data[:, :i2] = Q[:, :i2]
self.layer.weight.data[:, i2:] = W[:, i2:]
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
torch.cuda.synchronize()
logger.info(f'duration: {(time.time() - tick)}')
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
groupsize = groupsize if groupsize != -1 else self.columns
g_idx = [i // groupsize for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero, g_idx
def free(self):
if os.environ.get("DEBUG"):
self.inp1 = None
self.out1 = None
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
__all__ = ["GPTQ"]


@@ -0,0 +1,349 @@
import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(
self,
bits, perchannel=False, sym=True,
mse=False, norm=2.4, grid=100, maxshrink=.8,
trits=False
):
self.maxq = torch.tensor(2 ** bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
try:
import quant_cuda
is_cuda = True
except ImportError:
logger.warning('CUDA extension not installed.')
is_cuda = False
def make_quant(module, names, bits, groupsize, name=''):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
delattr(module, attr)
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.groupsize = groupsize if groupsize != -1 else infeatures
self.maxq = 2 ** self.bits - 1
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer('qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
# `wf` holds per-value bit offsets for the fallback path in forward(), where dequantization
# is performed by unpacking the weights and using torch.matmul
if self.bits in [2, 4, 8]:
self.register_buffer('wf', torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0),
persistent=False)
elif self.bits == 3:
self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
[0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
[0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], ],
dtype=torch.int32).reshape(1, 3, 12), persistent=False)
self.kernel_switch_threshold = kernel_switch_threshold
self.is_cuda = is_cuda
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
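# 3-bit values do not align to 32-bit words: every 3 consecutive int32 rows hold 32 values
# (10 whole values per row, with the 11th and 22nd values split across row boundaries)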
elif self.bits == 3:
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i))
i += 10
qweight[row] |= intweight[i] << 30
row += 1
qweight[row] |= (intweight[i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
i += 10
qweight[row] |= intweight[i] << 31
row += 1
qweight[row] |= (intweight[i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
i += 10
row += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
elif self.bits == 3:
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
i += 10
qzeros[:, col] |= zeros[:, i] << 30
col += 1
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
i += 10
qzeros[:, col] |= zeros[:, i] << 31
col += 1
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
i += 10
col += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
if self.is_cuda is True and (
self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold):
out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32)
if self.bits == 2:
quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 3:
quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 4:
quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 8:
quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
out = out.half()
else:
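# CUDA kernel unavailable (or batch beyond kernel_switch_threshold): dequantize the weights
# on the fly from qweight/qzeros/scales and fall back to a regular torch.matmul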
if self.bits in [2, 4, 8]:
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0)).to(
torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)
weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1)).to(
torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
elif self.bits == 3:
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(-1, -1, -1,
12)
zeros = (zeros >> self.wf.unsqueeze(0))
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
zeros = zeros & 0x7
zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)
zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)
weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(-1, -1,
12, -1)
weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
weight = weight & 0x7
weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]))
out = torch.matmul(x.half(), weights)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
__all__ = [
"quantize",
"make_quant",
"Quantizer",
"QuantLinear"
]
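
A small, self-contained sketch of the asymmetric quantization scheme implemented by quantize/Quantizer above (4-bit, per-output-channel; the random tensor is purely illustrative and the import assumes the auto_gptq package is installed):

import torch

from auto_gptq.quantization import Quantizer

w = torch.randn(8, 32)                 # toy weight matrix, rows = output channels
quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=False, mse=False)
quantizer.find_params(w, weight=True)  # per-row scale and zero point, maxq = 2**4 - 1

w_q = quantizer.quantize(w)            # round-trip: quantize then dequantize
max_err = (w - w_q).abs().max().item()
# element-wise error is roughly bounded by half a quantization step
print(f"max abs error: {max_err:.4f}, largest step size: {quantizer.scale.max().item():.4f}")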


@@ -0,0 +1,70 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
void vecquant2matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant2matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant3matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant3matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant4matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant4matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant8matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant8matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
}


@@ -0,0 +1,509 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
double* address,
double val
) {
unsigned long long int* address_as_ull = (unsigned long long int*)address;
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(
address_as_ull,
assumed,
__double_as_longlong(val + __longlong_as_double(assumed))
);
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
__device__ inline int as_int(int i) {
return *reinterpret_cast<int*>(&i);
}
void vecquant2matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros,
torch::Tensor g_idx
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant2matmul_cuda", ([&] {
VecQuant2MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
) {
int h = BLOCKHEIGHT2 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w;
int g_h = h * 16;
int k;
unsigned int g;
scalar_t w_tmp;
int z_w = w / 16;
int z_mod = (w % 16) * 2;
float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 16);
int k_bit = (k % 16) * 2;
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
weight[k] = scale * (w_tmp - zero);
}
scalar_t res;
for (int b = 0; b < batch; ++b){
res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k];
}
atomicAdd(&mul[b * width + w], res);
}
}
void vecquant3matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros,
torch::Tensor g_idx
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant3matmul_cuda", ([&] {
VecQuant3MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
) {
int h = BLOCKHEIGHT3 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w;
int g_h = (h / 3) * 32;
int k;
unsigned int g;
scalar_t w_tmp;
int z_w = (w / 32) * 3;
int z_mod = w % 32;
int z_bit;
unsigned int z_tmp;
if (z_mod != 10){
if (z_mod != 21){
z_bit = z_mod;
if (z_bit > 21){
z_bit -= 22;
z_bit *= 3;
z_bit += 2;
z_w += 2;
} else if (z_bit > 10){
z_bit -= 11;
z_bit *= 3;
z_bit += 1;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 32) * 3;
int k_mod = k % 32;
int k_bit;
if (k_mod != 10){
if (k_mod != 21){
k_bit = k_mod;
if (k_bit > 21){
k_bit -= 22;
k_bit *= 3;
k_bit += 2;
k_w += 2;
} else if (k_bit > 10){
k_bit -= 11;
k_bit *= 3;
k_bit += 1;
k_w += 1;
} else {
k_bit *= 3;
}
} else {
k_w += 1;
}
}
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = scalar_t((z_tmp) + 1);
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = scalar_t((z_tmp) + 1);
} else {
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
}
if (k_mod == 10) {
w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
} else if (k_mod == 21){
w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6);
} else {
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
}
weight[k] = scale * (w_tmp - zero);
}
scalar_t res;
for (int b = 0; b < batch; ++b){
res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k];
}
atomicAdd(&mul[b * width + w], res);
}
}
void vecquant4matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros,
torch::Tensor g_idx
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant4matmul_cuda", ([&] {
VecQuant4MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
) {
int h = BLOCKHEIGHT4 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w;
int g_h = h * 8;
int k;
unsigned int g;
scalar_t w_tmp;
int z_w = w / 8;
int z_mod = (w % 8) * 4;
float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 8);
int k_bit = (k % 8) * 4;
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
weight[k] = scale * (w_tmp - zero);
}
scalar_t res;
for (int b = 0; b < batch; ++b){
res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k];
}
atomicAdd(&mul[b * width + w], res);
}
}
void vecquant8matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros,
torch::Tensor g_idx
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant8matmul_cuda", ([&] {
VecQuant8MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
) {
int h = BLOCKHEIGHT8 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w;
int g_h = h * 4;
int k;
unsigned int g;
scalar_t w_tmp;
int z_w = w / 4;
int z_mod = (w % 4) * 8;
float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 4);
int k_bit = (k % 4) * 8;
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
weight[k] = scale * (w_tmp - zero);
}
scalar_t res;
for (int b = 0; b < batch; ++b){
res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k];
}
atomicAdd(&mul[b * width + w], res);
}
}


@@ -0,0 +1,10 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(
name='quant_cuda',
ext_modules=[cpp_extension.CUDAExtension(
'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
)],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)
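
As an alternative to installing the extension with this setup script, the kernels can also be JIT-compiled for a quick smoke test with PyTorch's extension loader (a sketch assuming the two source files above sit in the working directory and a CUDA toolchain is available):

from torch.utils.cpp_extension import load

quant_cuda = load(
    name="quant_cuda",
    sources=["quant_cuda.cpp", "quant_cuda_kernel.cu"],
    verbose=True,
)
print(hasattr(quant_cuda, "vecquant4matmul"))  # True once the module has been built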