Merge branch 'main' into xformers_integration

PanQiWei committed 2023-08-05 18:02:00 +08:00
commit 801610367d
9 changed files with 387 additions and 145 deletions

.github/workflows/build_wheels_rocm.yml (new file)

@@ -0,0 +1,82 @@
name: Build AutoGPTQ Wheels with ROCm
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
strategy:
matrix:
os: [ubuntu-latest]
python: ["3.8", "3.9", "3.10"] # what's the point?
rocm: ["5.4.2", "5.5", "5.6"]
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }}
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v3
with:
ref: 'main'
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python }}
- name: Free disk space
run: |
df -h
echo "Removing large packages"
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel
df -h
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get clean
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get autoclean -y >/dev/null 2>&1
df -h
echo "https://github.com/actions/virtual-environments/issues/709"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
df -h
echo "remove big /usr/local"
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
df -h
sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
sudo rm -rf /usr/share/swift > /dev/null 2>&1
df -h
- name: Set up environment
run: |
if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb
else
export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb
fi
curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/jammy/$ROCM_DL_FILE
sudo dpkg -i $ROCM_DL_FILE
sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends rocthrust-dev
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm }}
python -m pip install --upgrade build setuptools wheel ninja
- name: Build wheels
run: |
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
with:
name: 'wheels'
path: ./dist/*.whl

@@ -19,6 +19,7 @@
## News or Update
+- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized model to train adapters, support LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support download/upload quantized model from/to 🤗 Hub.
@@ -31,7 +32,7 @@
### Inference Speed
> The result is generated using [this script](examples/benchmark/generation_speed.py), batch size of input is 1, decode strategy is beam search and enforce the model to generate 512 tokens, speed metric is tokens/s (the larger, the better).
>
> The quantized model is loaded using the setup that can gain the fastest inference speed.
| model | GPU | num_beams | fp16 | gptq-int4 |
@@ -53,14 +54,17 @@ For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200
### Quick Installation
You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
Start from v0.2.0, you can download pre-build wheel that satisfied your environment setup from each version's release assets and install it to skip building stage for the fastest installation speed. For example:
```shell
# firstly, cd the directory where the wheel saved, then execute command below
pip install auto_gptq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl # install v0.2.0 auto_gptq pre-build wheel for linux in an environment whose python=3.10 and cuda=11.8
```
#### disable cuda extensions
By default, cuda extensions will be installed when `torch` and `cuda` is already installed in your machine, if you don't want to use them, using:
```shell
@@ -95,6 +99,12 @@ Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch
Use `.[triton]` if you want to integrate with triton and it's available on your operating system.
+To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. The compilation can be speeded up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
+```
+ROCM_VERSION=5.6 pip install .
+```
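A quick way to confirm that the installed PyTorch is itself a ROCm build before attempting the source install; a minimal sketch (the assertion mirrors the check this commit adds to setup.py):

```python
import torch

# torch.version.hip is a version string on ROCm builds of PyTorch and None otherwise;
# setup.py raises a ValueError if ROCM_VERSION is set but this is None.
assert torch.version.hip is not None, (
    "Install a ROCm build of PyTorch first, e.g. from the rocm5.6 wheel index."
)
print(f"ROCm-enabled PyTorch detected: HIP {torch.version.hip}")
```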
</details>
## Quick Tour
@@ -102,7 +112,7 @@ Use `.[triton]` if you want to integrate with triton and it's available on your
### Quantization and Inference
> warning: this is just a showcase of the usage of basic apis in AutoGPTQ, which uses only one sample to quantize a much small model, quality of quantized model using such little samples may not good.
Below is an example for the simplest use of `auto_gptq` to quantize a model and inference after quantization:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
@@ -125,7 +135,7 @@ examples = [
quantize_config = BaseQuantizeConfig(
bits=4, # quantize model to 4-bit
group_size=128, # it is recommended to set the value to 128
desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
)
# load un-quantized model, by default, the model will always be loaded into CPU memory
@@ -140,7 +150,7 @@ model.save_quantized(quantized_model_dir)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, Login first via huggingface-cli login.
# or pass explcit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
@@ -188,8 +198,8 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
"model.decoder.project_in", "model.decoder.final_layer_norm"
]
# chained attribute names of linear layers in transformer layer module
# normally, there are four sub lists, for each one the modules in it can be seen as one operation,
# and the order should be the order when they are truly executed, in this case (and usually in most cases),
# they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
@@ -260,14 +270,14 @@ task = SequenceClassificationTask(
"num_samples": 1000,  # how many samples will be sampled to evaluation
"sample_max_len": 1024,  # max tokens for each sample
"block_max_len": 2048,  # max tokens for each data block
# function to load dataset, one must only accept data_name_or_path as input
# and return datasets.Dataset
"load_fn": partial(datasets.load_dataset, name="english"),
# function to preprocess dataset, which is used for datasets.Dataset.map,
# must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name]
"preprocess_fn": ds_refactor_fn,
# truncate label when sample's length exceed sample_max_len
"truncate_prompt": False
}
)
@@ -296,7 +306,7 @@ print(
## Supported Models
> you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`.
>
> for example, model_type of `WizardLM`, `vicuna` and `gpt4all` are all `llama`, hence they are all supported by `auto_gptq`.
| model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
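As a concrete illustration of the `model.config.model_type` check noted above, the architecture can be inspected without loading any weights; a minimal sketch using a small OPT checkpoint as the example:

```python
from transformers import AutoConfig

# model_type is what auto_gptq matches against its supported-models table
config = AutoConfig.from_pretrained("facebook/opt-125m")
print(config.model_type)  # "opt", which appears in the table
```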
@@ -315,9 +325,17 @@ print(
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
+## Running tests
+Tests can be run with:
+```
+pytest tests/ -s
+```
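The same run can be started from Python if preferred, for example to target only the kernel tests added in this commit:

```python
import pytest

# equivalent to `pytest tests/test_q4.py -s` on the command line
raise SystemExit(pytest.main(["tests/test_q4.py", "-s"]))
```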
## Acknowledgement
- Specially thanks **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing **GPTQ** algorithm and open source the [code](https://github.com/IST-DASLab/gptq).
- Specially thanks **qwopqwop200**, for code in this project that relevant to quantization are mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).
[![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)

@@ -731,7 +731,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
raise TypeError(f"{config.model_type} isn't supported yet.")
if quantize_config is None:
-    quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **kwargs)
+    quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)
if model_basename is None:
if quantize_config.model_file_base_name:
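The effect of the change above is that `quantize_config.json` is resolved with the same download options (`cached_file_kwargs`) as the model weights themselves. A minimal sketch of the call involved, assuming a hypothetical Hub repository that ships a `quantize_config.json`:

```python
from auto_gptq import BaseQuantizeConfig

# loads quantize_config.json from the Hub or a local directory; the repo id below is hypothetical
cfg = BaseQuantizeConfig.from_pretrained("someuser/opt-125m-4bit-128g")
print(cfg.bits, cfg.group_size, cfg.desc_act)
```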

@@ -8,7 +8,6 @@ class GeneralQuantLinear(nn.Linear):
out_features=quant_linear_module.outfeatures,
bias=True
)
self.infeatures = quant_linear_module.infeatures
self.outfeatures = quant_linear_module.outfeatures
self.bits = quant_linear_module.bits
@@ -18,15 +17,15 @@ class GeneralQuantLinear(nn.Linear):
self.weight.requires_grad = False
self.weight.data = quant_linear_module.qweight
-self.qweight = self.weight
+self.register_buffer('qweight', quant_linear_module.qweight)
self.bias.data = quant_linear_module.bias
self.qweight.requires_grad = False
self.bias.requires_grad = False
-self.qzeros = quant_linear_module.qzeros
+self.register_buffer('qzeros', quant_linear_module.qzeros)
-self.scales = quant_linear_module.scales
+self.register_buffer('scales', quant_linear_module.scales)
-self.g_idx = quant_linear_module.g_idx
+self.register_buffer('g_idx', quant_linear_module.g_idx)
if hasattr(quant_linear_module, "wf"):
self.wf = quant_linear_module.wf
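For context on the switch from plain attribute assignment to `register_buffer` above: registered buffers follow the module through `.to()`/`.cuda()` and show up in `state_dict()`, while plain tensor attributes do not. An illustrative sketch (not AutoGPTQ code):

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("qzeros", torch.zeros(4, dtype=torch.int32))  # tracked buffer
        self.scales_plain = torch.ones(4)  # plain attribute, not tracked


m = Demo()
print("qzeros" in m.state_dict())        # True
print("scales_plain" in m.state_dict())  # False
```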

@@ -1,4 +1,6 @@
from packaging.version import parse as parse_version
+from logging import getLogger
+import torch
try:
import triton
@@ -14,9 +16,13 @@ try:
except:
AUTOGPTQ_CUDA_AVAILABLE = False
+logger = getLogger(__name__)
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
if use_triton:
+    if torch.version.hip:
+        logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
else:
if not desc_act or group_size == -1:
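A small illustration of the selection logic above: `torch.version.hip` is set only on ROCm builds of PyTorch, which is how the new warning decides when to discourage `use_triton=True`; with Triton disabled, the kernel class is chosen from `desc_act` and `group_size`:

```python
import torch
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

if torch.version.hip:
    print(f"ROCm build of PyTorch detected (HIP {torch.version.hip}); prefer use_triton=False")

# desc_act=False with a fixed group size selects the "old" CUDA kernel path
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128)
print(QuantLinear.__module__)
```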

@@ -30,8 +30,9 @@
// }
// #endif
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *address_as_ui;
@@ -76,7 +77,7 @@ __global__ void VecQuant2MatMulKernel(
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
@@ -91,7 +92,7 @@ __global__ void VecQuant3MatMulKernel(
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
@@ -121,7 +122,7 @@ __global__ void VecQuant8MatMulKernel(
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
@@ -135,7 +136,7 @@ __global__ void VecQuant2MatMulKernel_old(
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -150,7 +151,7 @@ __global__ void VecQuant3MatMulKernel_old(
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -165,7 +166,7 @@ __global__ void VecQuant4MatMulKernel_old(
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -180,7 +181,7 @@ __global__ void VecQuant8MatMulKernel_old(
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -208,7 +209,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
const float* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -222,7 +223,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
const float* __restrict__ scales,
const int* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width,
int zero_width,
@@ -269,7 +270,7 @@ void vecquant2matmul_cuda(
vec.type(), "vecquant2matmul_cuda", ([&] {
VecQuant2MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width
);
})
@ -292,39 +293,39 @@ __global__ void VecQuant2MatMulKernel(
) { ) {
int h = BLOCKHEIGHT2 * blockIdx.x; int h = BLOCKHEIGHT2 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH]; __shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w; int i = width * h + w;
int g_h = h * 16; int g_h = h * 16;
int k; int k;
unsigned int g; unsigned int g;
scalar_t w_tmp; scalar_t w_tmp;
int z_w = w / 16; int z_w = w / 16;
int z_mod = (w % 16) * 2; int z_mod = (w % 16) * 2;
float weight[BLOCKWIDTH]; float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 16); int k_w = (k / 16);
int k_bit = (k % 16) * 2; int k_bit = (k % 16) * 2;
g = as_int(g_idx[g_h + k]); g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
weight[k] = scale * (w_tmp - zero); weight[k] = scale * (w_tmp - zero);
} }
scalar_t res; scalar_t res;
for (int b = 0; b < batch; ++b){ for (int b = 0; b < batch; ++b){
res = 0; res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads(); __syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k]; res += weight[k] * blockvec[k];
} }
atomicAdd(&mul[b * width + w], res); atomicAdd(&mul[b * width + w], res);
@ -356,7 +357,7 @@ void vecquant3matmul_cuda(
vec.type(), "vecquant3matmul_cuda", ([&] { vec.type(), "vecquant3matmul_cuda", ([&] {
VecQuant3MatMulKernel<<<blocks, threads>>>( VecQuant3MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(), vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width batch, vec_height, height, width, zero_width
); );
}) })
@ -379,15 +380,15 @@ __global__ void VecQuant3MatMulKernel(
) { ) {
int h = BLOCKHEIGHT3 * blockIdx.x; int h = BLOCKHEIGHT3 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH]; __shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w; int i = width * h + w;
int g_h = (h / 3) * 32; int g_h = (h / 3) * 32;
int k; int k;
unsigned int g; unsigned int g;
scalar_t w_tmp; scalar_t w_tmp;
int z_w = (w / 32) * 3; int z_w = (w / 32) * 3;
int z_mod = w % 32; int z_mod = w % 32;
int z_bit; int z_bit;
unsigned int z_tmp; unsigned int z_tmp;
@ -411,14 +412,14 @@ __global__ void VecQuant3MatMulKernel(
z_w += 1; z_w += 1;
} }
} }
float weight[BLOCKWIDTH]; float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 32) * 3; int k_w = (k / 32) * 3;
int k_mod = k % 32; int k_mod = k % 32;
int k_bit; int k_bit;
if (k_mod != 10){ if (k_mod != 10){
if (k_mod != 21){ if (k_mod != 21){
k_bit = k_mod; k_bit = k_mod;
@ -439,7 +440,7 @@ __global__ void VecQuant3MatMulKernel(
k_w += 1; k_w += 1;
} }
} }
g = as_int(g_idx[g_h + k]); g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero; scalar_t zero;
@ -452,7 +453,7 @@ __global__ void VecQuant3MatMulKernel(
} else { } else {
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
} }
if (k_mod == 10) { if (k_mod == 10) {
w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
} else if (k_mod == 21){ } else if (k_mod == 21){
@ -464,12 +465,12 @@ __global__ void VecQuant3MatMulKernel(
} }
scalar_t res; scalar_t res;
for (int b = 0; b < batch; ++b){ for (int b = 0; b < batch; ++b){
res = 0; res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads(); __syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k]; res += weight[k] * blockvec[k];
} }
atomicAdd(&mul[b * width + w], res); atomicAdd(&mul[b * width + w], res);
@ -501,7 +502,7 @@ void vecquant4matmul_cuda(
vec.type(), "vecquant4matmul_cuda", ([&] { vec.type(), "vecquant4matmul_cuda", ([&] {
VecQuant4MatMulKernel<<<blocks, threads>>>( VecQuant4MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(), vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width batch, vec_height, height, width, zero_width
); );
}) })
@ -524,40 +525,40 @@ __global__ void VecQuant4MatMulKernel(
) { ) {
int h = BLOCKHEIGHT4 * blockIdx.x; int h = BLOCKHEIGHT4 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH]; __shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w; int i = width * h + w;
int g_h = h * 8; int g_h = h * 8;
int k; int k;
unsigned int g; unsigned int g;
scalar_t w_tmp; scalar_t w_tmp;
int z_w = w / 8;
int z_w = w / 8;
int z_mod = (w % 8) * 4; int z_mod = (w % 8) * 4;
float weight[BLOCKWIDTH]; float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 8); int k_w = (k / 8);
int k_bit = (k % 8) * 4; int k_bit = (k % 8) * 4;
g = as_int(g_idx[g_h + k]); g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
weight[k] = scale * (w_tmp - zero); weight[k] = scale * (w_tmp - zero);
} }
scalar_t res; scalar_t res;
for (int b = 0; b < batch; ++b){ for (int b = 0; b < batch; ++b){
res = 0; res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads(); __syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k]; res += weight[k] * blockvec[k];
} }
atomicAdd(&mul[b * width + w], res); atomicAdd(&mul[b * width + w], res);
@ -589,7 +590,7 @@ void vecquant8matmul_cuda(
vec.type(), "vecquant8matmul_cuda", ([&] { vec.type(), "vecquant8matmul_cuda", ([&] {
VecQuant8MatMulKernel<<<blocks, threads>>>( VecQuant8MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(), vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
batch, vec_height, height, width, zero_width batch, vec_height, height, width, zero_width
); );
}) })
@ -612,39 +613,39 @@ __global__ void VecQuant8MatMulKernel(
) { ) {
int h = BLOCKHEIGHT8 * blockIdx.x; int h = BLOCKHEIGHT8 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH]; __shared__ scalar_t blockvec[BLOCKWIDTH];
int i = width * h + w; int i = width * h + w;
int g_h = h * 4; int g_h = h * 4;
int k; int k;
unsigned int g; unsigned int g;
scalar_t w_tmp; scalar_t w_tmp;
int z_w = w / 4; int z_w = w / 4;
int z_mod = (w % 4) * 8; int z_mod = (w % 4) * 8;
float weight[BLOCKWIDTH]; float weight[BLOCKWIDTH];
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
int k_w = (k / 4); int k_w = (k / 4);
int k_bit = (k % 4) * 8; int k_bit = (k % 4) * 8;
g = as_int(g_idx[g_h + k]); g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
weight[k] = scale * (w_tmp - zero); weight[k] = scale * (w_tmp - zero);
} }
scalar_t res; scalar_t res;
for (int b = 0; b < batch; ++b){ for (int b = 0; b < batch; ++b){
res = 0; res = 0;
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
__syncthreads(); __syncthreads();
for (k = 0; k < BLOCKWIDTH; ++k){ for (k = 0; k < BLOCKWIDTH; ++k){
res += weight[k] * blockvec[k]; res += weight[k] * blockvec[k];
} }
atomicAdd(&mul[b * width + w], res); atomicAdd(&mul[b * width + w], res);
@ -711,19 +712,19 @@ __global__ void VecQuant2MatMulKernel_old(
int i = width * h + w; int i = width * h + w;
int g_h = h * 16; int g_h = h * 16;
int k = 0; int k = 0;
int z_w = w / 16; int z_w = w / 16;
int z_mod = (w % 16) * 2; int z_mod = (w % 16) * 2;
unsigned int tmp; unsigned int tmp;
while (k < BLOCKWIDTH) { while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]); tmp = as_unsigned(mat[i]);
int g = (g_h + k) / groupsize; int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
@ -740,7 +741,7 @@ __global__ void VecQuant2MatMulKernel_old(
res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13]; res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14]; res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15]; res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
i += width; i += width;
k += 16; k += 16;
} }
@ -806,11 +807,11 @@ __global__ void VecQuant3MatMulKernel_old(
int i = width * h + w; int i = width * h + w;
int g_h = (h / 3) * 32; int g_h = (h / 3) * 32;
int k = 0; int k = 0;
int z_w = (w / 32) * 3; int z_w = (w / 32) * 3;
int z_mod = w % 32; int z_mod = w % 32;
int z_bit; int z_bit;
if (z_mod != 10){ if (z_mod != 10){
if (z_mod != 21){ if (z_mod != 21){
z_bit = z_mod; z_bit = z_mod;
@ -831,7 +832,7 @@ __global__ void VecQuant3MatMulKernel_old(
z_w += 1; z_w += 1;
} }
} }
unsigned int tmp1; unsigned int tmp1;
unsigned int tmp2; unsigned int tmp2;
unsigned int tmp; unsigned int tmp;
@ -839,7 +840,7 @@ __global__ void VecQuant3MatMulKernel_old(
while (k < BLOCKWIDTH) { while (k < BLOCKWIDTH) {
tmp1 = as_unsigned(mat[i]); tmp1 = as_unsigned(mat[i]);
int g = (g_h + k) / groupsize; int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero; scalar_t zero;
@ -852,7 +853,7 @@ __global__ void VecQuant3MatMulKernel_old(
} else { } else {
zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
} }
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
@ -863,14 +864,14 @@ __global__ void VecQuant3MatMulKernel_old(
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width; i += width;
tmp2 = as_unsigned(mat[i]); tmp2 = as_unsigned(mat[i]);
tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
tmp2 >>= 1; tmp2 >>= 1;
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
k += 11; k += 11;
res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
@ -881,14 +882,14 @@ __global__ void VecQuant3MatMulKernel_old(
res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width; i += width;
tmp1 = as_unsigned(mat[i]); tmp1 = as_unsigned(mat[i]);
tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
tmp1 >>= 2; tmp1 >>= 2;
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
k += 11; k += 11;
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
@ -899,7 +900,7 @@ __global__ void VecQuant3MatMulKernel_old(
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width; i += width;
k += 10; k += 10;
} }
@ -966,18 +967,18 @@ __global__ void VecQuant4MatMulKernel_old(
int g_h = h * 8; int g_h = h * 8;
int k = 0; int k = 0;
int z_w = w / 8; int z_w = w / 8;
int z_mod = (w % 8) * 4; int z_mod = (w % 8) * 4;
unsigned int tmp; unsigned int tmp;
while (k < BLOCKWIDTH) { while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]); tmp = as_unsigned(mat[i]);
int g = (g_h + k) / groupsize; int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
@ -986,7 +987,7 @@ __global__ void VecQuant4MatMulKernel_old(
res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5]; res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6]; res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
i += width; i += width;
k += 8; k += 8;
} }
@ -1052,24 +1053,24 @@ __global__ void VecQuant8MatMulKernel_old(
int i = width * h + w; int i = width * h + w;
int g_h = h * 4; int g_h = h * 4;
int k = 0; int k = 0;
int z_w = w / 4; int z_w = w / 4;
int z_mod = (w % 4) * 8; int z_mod = (w % 4) * 8;
unsigned int tmp; unsigned int tmp;
while (k < BLOCKWIDTH) { while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]); tmp = as_unsigned(mat[i]);
int g = (g_h + k) / groupsize; int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w]; scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3]; res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
i += width; i += width;
k += 4; k += 4;
} }
@ -1091,7 +1092,7 @@ void vecquant2matmul_faster_cuda_old(
int height = mat.size(0); int height = mat.size(0);
int width = mat.size(1); int width = mat.size(1);
int zero_width = zeros.size(1); int zero_width = zeros.size(1);
dim3 blocks( dim3 blocks(
(height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH, (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@ -1143,8 +1144,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
int i = width * h + w; int i = width * h + w;
int g_h = h * 16; int g_h = h * 16;
int k = 0; int k = 0;
int z_w = w / 16; int z_w = w / 16;
int z_mod = (w % 16) * 2; int z_mod = (w % 16) * 2;
float res = 0; float res = 0;
@@ -1159,8 +1160,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
-res2 = {};
+std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
@@ -1172,7 +1173,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
i += width;
k += 8;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@@ -1191,7 +1192,7 @@ void vecquant3matmul_faster_cuda_old(
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@@ -1243,11 +1244,11 @@ __global__ void VecQuant3MatMulKernelFaster_old(
int i = width * h + w;
int g_h = (h / 3) * 32;
int k = 0;
int z_w = (w / 32) * 3;
int z_mod = w % 32;
int z_bit;
if (z_mod != 10){
if (z_mod != 21){
z_bit = z_mod;
@@ -1293,8 +1294,8 @@ __global__ void VecQuant3MatMulKernelFaster_old(
} else {
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
}
-res2 = {};
+std::memset(&res2, 0, sizeof(half2));
tmp1 = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
@@ -1324,7 +1325,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
i += width;
k += 5;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@@ -1343,7 +1344,7 @@ void vecquant4matmul_faster_cuda_old(
int height = mat.size(0);
int width = mat.size(1);
int zero_width = zeros.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@@ -1396,7 +1397,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
int g_h = h * 8;
int k = 0;
int z_w = w / 8;
int z_mod = (w % 8) * 4;
float res = 0;
@@ -1409,10 +1410,15 @@ __global__ void VecQuant4MatMulKernelFaster_old(
while (k < blockwidth2) {
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
-res2 = {};
+//std::memset(&res2, 0, sizeof(half2));
+//res2 = __float2half2_rn((float)0.);
+std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
@@ -1420,7 +1426,9 @@ __global__ void VecQuant4MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
i += width;
k += 4;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
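For readers following the kernel math, the per-element dequantization these 4-bit kernels apply is `scale * (q - (zero + 1))`, with weights and zero points packed eight per int32. A minimal PyTorch reference (not AutoGPTQ code; shapes assume the standard GPTQ packing used in the kernels above):

```python
import torch

def dequantize_4bit(qweight, qzeros, scales, g_idx):
    # qweight: (in_features // 8, out_features) int32, packed along the input dimension
    # qzeros:  (n_groups, out_features // 8) int32, packed along the output dimension
    # scales:  (n_groups, out_features); g_idx: (in_features,) group index per input row
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device)
    q = ((qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF).reshape(-1, qweight.shape[1])
    z = ((qzeros.unsqueeze(2) >> shifts.view(1, 1, 8)) & 0xF).reshape(qzeros.shape[0], -1)
    return scales[g_idx] * (q.to(scales.dtype) - (z[g_idx].to(scales.dtype) + 1))
```

A matmul against this dequantized matrix should agree with the kernel output up to fp16 rounding, which is essentially what the precomputed references in tests/test_q4.py encode.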

@@ -30,7 +30,8 @@
// }
// #endif
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
@@ -1161,7 +1162,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
-res2 = {};
+std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
@@ -1173,7 +1174,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
i += width;
k += 8;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@@ -1295,7 +1296,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
}
-res2 = {};
+std::memset(&res2, 0, sizeof(half2));
tmp1 = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
@@ -1325,7 +1326,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
i += width;
k += 5;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@@ -1413,7 +1414,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
-res2 = {};
+std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
@@ -1421,7 +1422,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
i += width;
k += 4;
-res += __half2float(res2.x) + __half2float(res2.y);
+res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);

@@ -18,9 +18,15 @@ if BUILD_CUDA_EXT:
except:
print("torch is not installed, please install torch first!")
sys.exit(-1)
-    CUDA_VERSION = "".join(torch.version.cuda.split("."))
-else:
-    CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", "").split("."))
+CUDA_VERSION = False
+ROCM_VERSION = os.environ.get('ROCM_VERSION', False)
+if ROCM_VERSION and not torch.version.hip:
+    raise ValueError(f"Trying to compile AutoGPTQ for RoCm, but PyTorch {torch.__version__} is installed with no RoCm support.")
+if not ROCM_VERSION:
+    default_cuda_version = "".join(torch.version.cuda.split("."))
+    CUDA_VERSION = os.environ.get("CUDA_VERSION", default_cuda_version)
common_setup_kwargs = {
"version": "0.3.2",
@@ -46,8 +52,13 @@ common_setup_kwargs = {
"python_requires": f">={python_min_version_str}"
}
-if CUDA_VERSION:
-    common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"
+if BUILD_CUDA_EXT:
+    if ROCM_VERSION:
+        common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}"
+    else:
+        assert CUDA_VERSION
+        common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"
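The resulting local version suffix can be previewed with a small sketch of the logic added above (an illustration, not the actual setup.py):

```python
import os
import torch

version = "0.3.2"
rocm_version = os.environ.get("ROCM_VERSION", False)
if rocm_version:
    version += f"+rocm{rocm_version}"  # e.g. 0.3.2+rocm5.6
elif torch.version.cuda:
    cuda = os.environ.get("CUDA_VERSION", "".join(torch.version.cuda.split(".")))
    version += f"+cu{cuda}"            # e.g. 0.3.2+cu118
print(version)
```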
requirements = [
"accelerate>=0.19.0",
@@ -72,11 +83,15 @@ include_dirs = ["autogptq_cuda"]
additional_setup_kwargs = dict()
if BUILD_CUDA_EXT:
from torch.utils import cpp_extension
-    from distutils.sysconfig import get_python_lib
-    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
-    if os.path.isdir(conda_cuda_include_dir):
-        include_dirs.append(conda_cuda_include_dir)
-        print(f"appending conda cuda include dir {conda_cuda_include_dir}")
+    if not ROCM_VERSION:
+        from distutils.sysconfig import get_python_lib
+        conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
+        print("conda_cuda_include_dir", conda_cuda_include_dir)
+        if os.path.isdir(conda_cuda_include_dir):
+            include_dirs.append(conda_cuda_include_dir)
+            print(f"appending conda cuda include dir {conda_cuda_include_dir}")
extensions = [
cpp_extension.CUDAExtension(
"autogptq_cuda_64",

tests/test_q4.py (new file)

@@ -0,0 +1,113 @@
import unittest
from parameterized import parameterized
import torch
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
def get_diff(a, ref):
eps = 1e-6
return f"Maxdiff: {(a - ref).abs().max()}, Mean relative diff: {((a - ref).abs() / (ref.abs() + eps)).mean()}"
class TestsQ4CUDA(unittest.TestCase):
REFERENCE_OLD_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.8008, 1.9688, 1.3262, 1.7627, 1.8164, 1.9307,
1.8574, 1.5449, 1.5293, 1.6074, 1.5566, 1.8545, 1.6582, 1.8838, 2.0215,
1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
1.1826, 1.8174, 2.1191, 1.6641, 2.0586, 1.6182, 1.7627, 1.7920, 1.4424,
2.0723, 1.6865, 1.2979, 2.0840, 1.6729, 1.9648, 2.1602, 1.6006, 1.2773,
2.2129, 1.8057, 1.7285, 1.6621, 1.6475, 1.4805, 1.7959, 1.5010, 0.8643,
2.6680, 2.0918, 1.8555, 1.9795, 1.3271, 1.8359, 1.6338, 1.9766, 1.7881,
1.6025, 1.7637, 1.7012, 1.7852, 1.5674, 0.8091, 1.7188, 1.6123, 1.8525,
1.4434, 1.9590, 1.5801, 1.4209, 1.7178, 1.8408, 2.4141, 1.9658, 1.4922,
2.1992, 1.9473, 1.8047, 1.2979, 1.6396, 1.6221, 1.5020, 1.9941, 1.7725,
1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7197, 1.7686,
1.4160, 1.7275, 2.1738, 1.9609, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
2.1113, 1.7227, 1.5811, 1.7607, 2.2773, 1.8945, 1.4111, 1.5801, 1.7744,
2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7803, 1.7744, 1.5312,
1.2695, 1.9209, 2.0469, 1.6777, 2.5215, 1.8389, 1.7598, 1.5498, 1.6807,
1.7324, 1.5938, 1.9268, 1.7734, 1.4463, 2.0391, 2.0527, 2.2129, 1.6787,
2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7080, 1.1611, 1.8584,
2.4570, 1.6016, 1.4834, 1.1777, 1.7969, 1.8955, 1.8906, 1.6738, 1.7510,
1.4316, 1.8340, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
1.8887, 1.5137, 1.4648, 1.6406, 1.7188, 2.2656, 1.5801, 2.1484, 2.0625,
2.0098, 1.7549, 1.1768, 1.4385, 2.0723, 1.6172, 1.7832, 1.8301, 1.6064,
1.5215, 1.9297, 2.3750, 2.1504, 1.7070, 1.1289, 1.4473, 1.5674, 1.6836,
2.2930, 1.1221, 1.5557, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
1.6348, 2.3379, 2.4414, 1.8857, 2.0039, 1.4844, 1.5488, 1.6514, 2.3711,
1.9941, 2.3066, 1.4287, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
2.1309, 1.2666, 1.1523, 1.9561, 1.8584, 1.9746, 1.5986, 1.9688, 2.1973,
1.1523, 2.3281, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)
REFERENCE_OLD_NO_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.7998, 1.9678, 1.3262, 1.7617, 1.8154, 1.9307,
1.8574, 1.5449, 1.5293, 1.6074, 1.5557, 1.8545, 1.6582, 1.8838, 2.0195,
1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
1.1826, 1.8164, 2.1191, 1.6641, 2.0586, 1.6182, 1.7617, 1.7920, 1.4424,
2.0723, 1.6865, 1.2969, 2.0840, 1.6729, 1.9639, 2.1602, 1.5996, 1.2773,
2.2129, 1.8057, 1.7275, 1.6621, 1.6475, 1.4805, 1.7949, 1.5010, 0.8643,
2.6680, 2.0918, 1.8545, 1.9795, 1.3271, 1.8350, 1.6338, 1.9766, 1.7881,
1.6025, 1.7637, 1.7012, 1.7842, 1.5664, 0.8086, 1.7188, 1.6113, 1.8516,
1.4434, 1.9590, 1.5801, 1.4209, 1.7168, 1.8408, 2.4141, 1.9658, 1.4922,
2.1973, 1.9463, 1.8047, 1.2979, 1.6396, 1.6221, 1.5010, 1.9941, 1.7725,
1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7188, 1.7686,
1.4160, 1.7266, 2.1738, 1.9600, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
2.1113, 1.7227, 1.5811, 1.7598, 2.2773, 1.8936, 1.4102, 1.5801, 1.7734,
2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7793, 1.7744, 1.5312,
1.2695, 1.9209, 2.0469, 1.6777, 2.5195, 1.8389, 1.7598, 1.5498, 1.6797,
1.7324, 1.5928, 1.9258, 1.7734, 1.4463, 2.0391, 2.0508, 2.2129, 1.6787,
2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7070, 1.1602, 1.8584,
2.4570, 1.6016, 1.4834, 1.1777, 1.7959, 1.8955, 1.8906, 1.6738, 1.7510,
1.4316, 1.8330, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
1.8887, 1.5137, 1.4648, 1.6406, 1.7178, 2.2637, 1.5801, 2.1484, 2.0605,
2.0098, 1.7539, 1.1768, 1.4375, 2.0723, 1.6162, 1.7832, 1.8291, 1.6064,
1.5215, 1.9297, 2.3750, 2.1504, 1.7061, 1.1289, 1.4473, 1.5674, 1.6836,
2.2930, 1.1221, 1.5547, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
1.6348, 2.3379, 2.4414, 1.8857, 2.0020, 1.4834, 1.5488, 1.6514, 2.3711,
1.9941, 2.3047, 1.4277, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
2.1309, 1.2666, 1.1514, 1.9551, 1.8584, 1.9746, 1.5986, 1.9688, 2.1953,
1.1514, 2.3262, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)
@parameterized.expand([False, True])
def test_cuda_old(self, use_half2: bool):
group_size = 128
# test the 256 kernel (in_features % 256 == 0 and out_features % 256 == 0)
m = 1
k = 256
n = 256
device = "cuda"
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=k,
outfeatures=n,
bias=False,
)
torch.manual_seed(42)
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear.use_cuda_fp16 = use_half2
self.assertTrue(linear.autogptq_cuda_available)
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
linear = linear.eval()
linear = linear.to(device)
with torch.no_grad():
res = linear(inp)[0][0]
if use_half2:
reference = self.REFERENCE_OLD_HALF.to(device)
else:
reference = self.REFERENCE_OLD_NO_HALF.to(device)
self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))