Merge branch 'main' into xformers_integration
This commit is contained in: commit 801610367d
9 changed files with 387 additions and 145 deletions
.github/workflows/build_wheels_rocm.yml (new file, 82 lines, vendored)

@@ -0,0 +1,82 @@
name: Build AutoGPTQ Wheels with ROCm

on: workflow_dispatch

jobs:
  build_wheels:
    if: ${{ github.repository_owner == 'PanQiWei' }}

    strategy:
      matrix:
        os: [ubuntu-latest]
        python: ["3.8", "3.9", "3.10"] # what's the point?
        rocm: ["5.4.2", "5.5", "5.6"]

    name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }}
    runs-on: ${{ matrix.os }}

    defaults:
      run:
        shell: bash

    steps:
      - uses: actions/checkout@v3
        with:
          ref: 'main'

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}

      - name: Free disk space
        run: |
          df -h
          echo "Removing large packages"
          sudo apt-get remove -y '^dotnet-.*'
          sudo apt-get remove -y 'php.*'
          sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel
          df -h
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get clean
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get autoclean -y >/dev/null 2>&1
          df -h
          echo "https://github.com/actions/virtual-environments/issues/709"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
          df -h
          echo "remove big /usr/local"
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
          df -h
          sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
          sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
          sudo rm -rf /usr/share/swift > /dev/null 2>&1
          df -h

      - name: Set up environment
        run: |
          if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
          elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb
          else
            export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb
          fi

          curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/jammy/$ROCM_DL_FILE
          sudo dpkg -i $ROCM_DL_FILE
          sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y

      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y --no-install-recommends rocthrust-dev

          python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm }}
          python -m pip install --upgrade build setuptools wheel ninja

      - name: Build wheels
        run: |
          ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel

      - uses: actions/upload-artifact@v3
        with:
          name: 'wheels'
          path: ./dist/*.whl
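The "Set up environment" step above hard-codes the mapping from each ROCm release to its amdgpu-install package name. As a minimal sketch (the helper below is hypothetical, not part of the workflow or the library; only the file names and URL pattern are taken from the workflow), the same mapping in Python:

```python
# Hypothetical helper mirroring the bash if/elif/else in the workflow above.
ROCM_INSTALLERS = {
    "5.4.2": "amdgpu-install_5.4.50402-1_all.deb",
    "5.5": "amdgpu-install_5.5.50500-1_all.deb",
    "5.6": "amdgpu-install_5.6.50600-1_all.deb",
}

def rocm_installer_url(rocm_version: str) -> str:
    """Return the download URL the workflow uses for a given ROCm version."""
    file_name = ROCM_INSTALLERS[rocm_version]
    return f"https://repo.radeon.com/amdgpu-install/{rocm_version}/ubuntu/jammy/{file_name}"

if __name__ == "__main__":
    print(rocm_installer_url("5.6"))
```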
README.md (44 lines changed)
@@ -19,6 +19,7 @@
## News or Update

- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized model to train adapters, support LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support download/upload quantized model from/to 🤗 Hub.
@@ -31,7 +32,7 @@
### Inference Speed
> The result is generated using [this script](examples/benchmark/generation_speed.py); the input batch size is 1, the decode strategy is beam search, the model is forced to generate 512 tokens, and the speed metric is tokens/s (the larger, the better).
>
> The quantized model is loaded using the setup that can gain the fastest inference speed.

| model | GPU | num_beams | fp16 | gptq-int4 |
@@ -53,14 +54,17 @@ For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200
### Quick Installation
You can install the latest stable release of AutoGPTQ from pip:

```shell
pip install auto-gptq
```

Starting from v0.2.0, you can download a pre-built wheel that matches your environment from each version's release assets and install it to skip the build stage for the fastest installation. For example:
```shell
# first, cd to the directory where the wheel is saved, then execute the command below
pip install auto_gptq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl # install the v0.2.0 auto_gptq pre-built wheel for linux in an environment whose python=3.10 and cuda=11.8
```

#### disable cuda extensions
By default, CUDA extensions will be installed when `torch` and CUDA are already installed on your machine; if you don't want to use them, use:
```shell
@@ -95,6 +99,12 @@ Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch
Use `.[triton]` if you want to integrate with triton and it's available on your operating system.

To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. Compilation can be sped up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:

```
ROCM_VERSION=5.6 pip install .
```
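Building for RoCm requires a PyTorch build that already ships HIP support; the updated `setup.py` in this diff raises an error otherwise. A minimal pre-flight check, sketched in Python (not an official AutoGPTQ command), could be:

```python
# Sketch: verify the installed PyTorch was built for ROCm before running
# `ROCM_VERSION=... pip install .`; mirrors the check done in setup.py.
import torch

if torch.version.hip is None:
    raise SystemExit(
        f"PyTorch {torch.__version__} has no ROCm/HIP support; "
        "install a ROCm build (e.g. from the nightly rocm index) before building AutoGPTQ."
    )
print(f"Found PyTorch {torch.__version__} with HIP {torch.version.hip}")
```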
</details>

## Quick Tour
@@ -102,7 +112,7 @@ Use `.[triton]` if you want to integrate with triton and it's available on your
### Quantization and Inference
> warning: this is just a showcase of the basic apis in AutoGPTQ, which uses only one sample to quantize a much smaller model; the quality of a model quantized with so few samples may not be good.

Below is an example of the simplest use of `auto_gptq` to quantize a model and run inference after quantization:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
@@ -125,7 +135,7 @@ examples = [
quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # set to False to significantly speed up inference, at the cost of slightly worse perplexity
)

# load un-quantized model; by default, the model will always be loaded into CPU memory
@@ -140,7 +150,7 @@ model.save_quantized(quantized_model_dir)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, log in first via huggingface-cli login.
# or pass an explicit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
@@ -188,8 +198,8 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
        "model.decoder.project_in", "model.decoder.final_layer_norm"
    ]
    # chained attribute names of linear layers in transformer layer module
    # normally, there are four sub lists, for each one the modules in it can be seen as one operation,
    # and the order should be the order when they are truly executed, in this case (and usually in most cases),
    # they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
@@ -260,14 +270,14 @@ task = SequenceClassificationTask(
        "num_samples": 1000,  # how many samples will be sampled for evaluation
        "sample_max_len": 1024,  # max tokens for each sample
        "block_max_len": 2048,  # max tokens for each data block
        # function to load dataset, one must only accept data_name_or_path as input
        # and return datasets.Dataset
        "load_fn": partial(datasets.load_dataset, name="english"),
        # function to preprocess dataset, which is used for datasets.Dataset.map,
        # must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name]
        "preprocess_fn": ds_refactor_fn,
        # truncate label when sample's length exceeds sample_max_len
        "truncate_prompt": False
    }
)
@@ -296,7 +306,7 @@ print(
## Supported Models

> you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`.
>
> for example, model_type of `WizardLM`, `vicuna` and `gpt4all` are all `llama`, hence they are all supported by `auto_gptq`.

| model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
@@ -315,9 +325,17 @@ print(
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!

## Running tests

Tests can be run with:

```
pytest tests/ -s
```

## Acknowledgement
- Special thanks to **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing the **GPTQ** algorithm and open-sourcing the [code](https://github.com/IST-DASLab/gptq).
- Special thanks to **qwopqwop200**; the quantization-related code in this project is mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).

[](https://star-history.com/#PanQiWei/AutoGPTQ&Date)
@@ -731,7 +731,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
            raise TypeError(f"{config.model_type} isn't supported yet.")

        if quantize_config is None:
-           quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **kwargs)
+           quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)

        if model_basename is None:
            if quantize_config.model_file_base_name:
@@ -8,7 +8,6 @@ class GeneralQuantLinear(nn.Linear):
            out_features=quant_linear_module.outfeatures,
            bias=True
        )

        self.infeatures = quant_linear_module.infeatures
        self.outfeatures = quant_linear_module.outfeatures
        self.bits = quant_linear_module.bits
@@ -18,15 +17,15 @@ class GeneralQuantLinear(nn.Linear):
        self.weight.requires_grad = False

        self.weight.data = quant_linear_module.qweight
-       self.qweight = self.weight
+       self.register_buffer('qweight', quant_linear_module.qweight)
        self.bias.data = quant_linear_module.bias

        self.qweight.requires_grad = False
        self.bias.requires_grad = False

-       self.qzeros = quant_linear_module.qzeros
-       self.scales = quant_linear_module.scales
-       self.g_idx = quant_linear_module.g_idx
+       self.register_buffer('qzeros', quant_linear_module.qzeros)
+       self.register_buffer('scales', quant_linear_module.scales)
+       self.register_buffer('g_idx', quant_linear_module.g_idx)

        if hasattr(quant_linear_module, "wf"):
            self.wf = quant_linear_module.wf
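The switch to `register_buffer` here (instead of plain attribute assignment) presumably makes these tensors follow the module through device moves and show up in `state_dict()`. A small self-contained PyTorch sketch of that behaviour, unrelated to the AutoGPTQ classes:

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: moves with the module and is saved in state_dict
        self.register_buffer("qzeros", torch.zeros(4, dtype=torch.int32))
        # plain attribute: ignored by .to() and state_dict()
        self.plain = torch.zeros(4, dtype=torch.int32)

m = WithBuffer()
print("qzeros" in m.state_dict())  # True
print("plain" in m.state_dict())   # False
```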
@@ -1,4 +1,6 @@
from packaging.version import parse as parse_version
from logging import getLogger
import torch

try:
    import triton
@@ -14,9 +16,13 @@ try:
except:
    AUTOGPTQ_CUDA_AVAILABLE = False

logger = getLogger(__name__)

def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
    if use_triton:
        if torch.version.hip:
            logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")

        from ..nn_modules.qlinear.qlinear_triton import QuantLinear
    else:
        if not desc_act or group_size == -1:
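For context, the new test added later in this diff calls this helper directly. A minimal usage sketch, with the parameters taken from `tests/test_q4.py` below:

```python
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

# Picks the Triton or CUDA QuantLinear implementation depending on the flags.
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128)
linear = QuantLinear(bits=4, group_size=128, infeatures=256, outfeatures=256, bias=False)
```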
@@ -30,8 +30,9 @@
// }
// #endif

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh

__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
    unsigned int old = *address_as_ui;

@@ -76,7 +77,7 @@ __global__ void VecQuant2MatMulKernel(
    const int* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width

@@ -91,7 +92,7 @@ __global__ void VecQuant3MatMulKernel(
    const int* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width

@@ -121,7 +122,7 @@ __global__ void VecQuant8MatMulKernel(
    const int* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width

@@ -135,7 +136,7 @@ __global__ void VecQuant2MatMulKernel_old(
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -150,7 +151,7 @@ __global__ void VecQuant3MatMulKernel_old(
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -165,7 +166,7 @@ __global__ void VecQuant4MatMulKernel_old(
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -180,7 +181,7 @@ __global__ void VecQuant8MatMulKernel_old(
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -208,7 +209,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
    const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -222,7 +223,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
    const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,

@@ -269,7 +270,7 @@ void vecquant2matmul_cuda(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
@@ -292,39 +293,39 @@ __global__ void VecQuant2MatMulKernel(
) {
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 16;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k){
    int k_w = (k / 16);
    int k_bit = (k % 16) * 2;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    res = 0;

    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);

@@ -356,7 +357,7 @@ void vecquant3matmul_cuda(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })

@@ -379,15 +380,15 @@ __global__ void VecQuant3MatMulKernel(
) {
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;
  unsigned int z_tmp;

@@ -411,14 +412,14 @@ __global__ void VecQuant3MatMulKernel(
      z_w += 1;
    }
  }

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k){
    int k_w = (k / 32) * 3;
    int k_mod = k % 32;
    int k_bit;

    if (k_mod != 10){
      if (k_mod != 21){
        k_bit = k_mod;

@@ -439,7 +440,7 @@ __global__ void VecQuant3MatMulKernel(
        k_w += 1;
      }
    }

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero;

@@ -452,7 +453,7 @@ __global__ void VecQuant3MatMulKernel(
    } else {
      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    if (k_mod == 10) {
      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
    } else if (k_mod == 21){

@@ -464,12 +465,12 @@ __global__ void VecQuant3MatMulKernel(
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    res = 0;

    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);

@@ -501,7 +502,7 @@ void vecquant4matmul_cuda(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
@@ -524,40 +525,40 @@ __global__ void VecQuant4MatMulKernel(
) {
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 8;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k){
    int k_w = (k / 8);
    int k_bit = (k % 8) * 4;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    res = 0;

    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);

@@ -589,7 +590,7 @@ void vecquant8matmul_cuda(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })

@@ -612,39 +613,39 @@ __global__ void VecQuant8MatMulKernel(
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k){
    int k_w = (k / 4);
    int k_bit = (k % 4) * 8;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    res = 0;

    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
@@ -711,19 +712,19 @@ __global__ void VecQuant2MatMulKernel_old(
  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];

@@ -740,7 +741,7 @@ __global__ void VecQuant2MatMulKernel_old(
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

@@ -806,11 +807,11 @@ __global__ void VecQuant3MatMulKernel_old(
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;

@@ -831,7 +832,7 @@ __global__ void VecQuant3MatMulKernel_old(
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

@@ -839,7 +840,7 @@ __global__ void VecQuant3MatMulKernel_old(

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;

@@ -852,7 +853,7 @@ __global__ void VecQuant3MatMulKernel_old(
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];

@@ -863,14 +864,14 @@ __global__ void VecQuant3MatMulKernel_old(
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];

@@ -881,14 +882,14 @@ __global__ void VecQuant3MatMulKernel_old(
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];

@@ -899,7 +900,7 @@ __global__ void VecQuant3MatMulKernel_old(
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

@@ -966,18 +967,18 @@ __global__ void VecQuant4MatMulKernel_old(
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];

@@ -986,7 +987,7 @@ __global__ void VecQuant4MatMulKernel_old(
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

@@ -1052,24 +1053,24 @@ __global__ void VecQuant8MatMulKernel_old(
  int i = width * h + w;
  int g_h = h * 4;
  int k = 0;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }
@@ -1091,7 +1092,7 @@ void vecquant2matmul_faster_cuda_old(
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,

@@ -1143,8 +1144,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  float res = 0;

@@ -1159,8 +1160,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));

-   res2 = {};
+   std::memset(&res2, 0, sizeof(half2));
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);

@@ -1172,7 +1173,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
    i += width;
    k += 8;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);

@@ -1191,7 +1192,7 @@ void vecquant3matmul_faster_cuda_old(
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,

@@ -1243,11 +1244,11 @@ __global__ void VecQuant3MatMulKernelFaster_old(
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;

@@ -1293,8 +1294,8 @@ __global__ void VecQuant3MatMulKernelFaster_old(
    } else {
      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
    }

-   res2 = {};
+   std::memset(&res2, 0, sizeof(half2));
    tmp1 = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);

@@ -1324,7 +1325,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    k += 5;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);

@@ -1343,7 +1344,7 @@ void vecquant4matmul_faster_cuda_old(
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,

@@ -1396,7 +1397,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float res = 0;

@@ -1409,10 +1410,15 @@ __global__ void VecQuant4MatMulKernelFaster_old(
  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];

    half2 scale = __float2half2_rn(scale_f);
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));

-   res2 = {};
+   //std::memset(&res2, 0, sizeof(half2));
+   //res2 = __float2half2_rn((float)0.);
+   std::memset(&res2, 0, sizeof(half2));
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);

@@ -1420,7 +1426,9 @@ __global__ void VecQuant4MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    k += 4;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);
@@ -30,7 +30,8 @@
// }
// #endif

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));

@@ -1161,7 +1162,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
    half2 scale = __float2half2_rn(scale_f);
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));

-   res2 = {};
+   std::memset(&res2, 0, sizeof(half2));
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);

@@ -1173,7 +1174,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
    i += width;
    k += 8;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);

@@ -1295,7 +1296,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
    }

-   res2 = {};
+   std::memset(&res2, 0, sizeof(half2));
    tmp1 = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);

@@ -1325,7 +1326,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    k += 5;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);

@@ -1413,7 +1414,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
    half2 scale = __float2half2_rn(scale_f);
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));

-   res2 = {};
+   std::memset(&res2, 0, sizeof(half2));
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);

@@ -1421,7 +1422,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    k += 4;
-   res += __half2float(res2.x) + __half2float(res2.y);
+   res += __low2float(res2) + __high2float(res2);
  }

  atomicAdd(&mul[b * width + w], res);
setup.py (35 lines changed)
@@ -18,9 +18,15 @@ if BUILD_CUDA_EXT:
    except:
        print("torch is not installed, please install torch first!")
        sys.exit(-1)
-   CUDA_VERSION = "".join(torch.version.cuda.split("."))
-else:
-   CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", "").split("."))
+
+   CUDA_VERSION = False
+   ROCM_VERSION = os.environ.get('ROCM_VERSION', False)
+   if ROCM_VERSION and not torch.version.hip:
+       raise ValueError(f"Trying to compile AutoGPTQ for RoCm, but PyTorch {torch.__version__} is installed with no RoCm support.")
+
+   if not ROCM_VERSION:
+       default_cuda_version = "".join(torch.version.cuda.split("."))
+       CUDA_VERSION = os.environ.get("CUDA_VERSION", default_cuda_version)

common_setup_kwargs = {
    "version": "0.3.2",
@@ -46,8 +52,13 @@ common_setup_kwargs = {
    "python_requires": f">={python_min_version_str}"
}

-if CUDA_VERSION:
-    common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"
+if BUILD_CUDA_EXT:
+    if ROCM_VERSION:
+        common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}"
+    else:
+        assert CUDA_VERSION
+        common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"

requirements = [
    "accelerate>=0.19.0",
@@ -72,11 +83,15 @@ include_dirs = ["autogptq_cuda"]
additional_setup_kwargs = dict()
if BUILD_CUDA_EXT:
    from torch.utils import cpp_extension
-   from distutils.sysconfig import get_python_lib
-   conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
-   if os.path.isdir(conda_cuda_include_dir):
-       include_dirs.append(conda_cuda_include_dir)
-       print(f"appending conda cuda include dir {conda_cuda_include_dir}")
+
+   if not ROCM_VERSION:
+       from distutils.sysconfig import get_python_lib
+       conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
+
+       print("conda_cuda_include_dir", conda_cuda_include_dir)
+       if os.path.isdir(conda_cuda_include_dir):
+           include_dirs.append(conda_cuda_include_dir)
+           print(f"appending conda cuda include dir {conda_cuda_include_dir}")
    extensions = [
        cpp_extension.CUDAExtension(
            "autogptq_cuda_64",
tests/test_q4.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import unittest
from parameterized import parameterized
import torch
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

def get_diff(a, ref):
    eps = 1e-6
    return f"Maxdiff: {(a - ref).abs().max()}, Mean relative diff: {((a - ref).abs() / (ref.abs() + eps)).mean()}"

class TestsQ4CUDA(unittest.TestCase):

    REFERENCE_OLD_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.8008, 1.9688, 1.3262, 1.7627, 1.8164, 1.9307,
        1.8574, 1.5449, 1.5293, 1.6074, 1.5566, 1.8545, 1.6582, 1.8838, 2.0215,
        1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
        1.1826, 1.8174, 2.1191, 1.6641, 2.0586, 1.6182, 1.7627, 1.7920, 1.4424,
        2.0723, 1.6865, 1.2979, 2.0840, 1.6729, 1.9648, 2.1602, 1.6006, 1.2773,
        2.2129, 1.8057, 1.7285, 1.6621, 1.6475, 1.4805, 1.7959, 1.5010, 0.8643,
        2.6680, 2.0918, 1.8555, 1.9795, 1.3271, 1.8359, 1.6338, 1.9766, 1.7881,
        1.6025, 1.7637, 1.7012, 1.7852, 1.5674, 0.8091, 1.7188, 1.6123, 1.8525,
        1.4434, 1.9590, 1.5801, 1.4209, 1.7178, 1.8408, 2.4141, 1.9658, 1.4922,
        2.1992, 1.9473, 1.8047, 1.2979, 1.6396, 1.6221, 1.5020, 1.9941, 1.7725,
        1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7197, 1.7686,
        1.4160, 1.7275, 2.1738, 1.9609, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
        2.1113, 1.7227, 1.5811, 1.7607, 2.2773, 1.8945, 1.4111, 1.5801, 1.7744,
        2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
        2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7803, 1.7744, 1.5312,
        1.2695, 1.9209, 2.0469, 1.6777, 2.5215, 1.8389, 1.7598, 1.5498, 1.6807,
        1.7324, 1.5938, 1.9268, 1.7734, 1.4463, 2.0391, 2.0527, 2.2129, 1.6787,
        2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7080, 1.1611, 1.8584,
        2.4570, 1.6016, 1.4834, 1.1777, 1.7969, 1.8955, 1.8906, 1.6738, 1.7510,
        1.4316, 1.8340, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
        1.8887, 1.5137, 1.4648, 1.6406, 1.7188, 2.2656, 1.5801, 2.1484, 2.0625,
        2.0098, 1.7549, 1.1768, 1.4385, 2.0723, 1.6172, 1.7832, 1.8301, 1.6064,
        1.5215, 1.9297, 2.3750, 2.1504, 1.7070, 1.1289, 1.4473, 1.5674, 1.6836,
        2.2930, 1.1221, 1.5557, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
        1.6348, 2.3379, 2.4414, 1.8857, 2.0039, 1.4844, 1.5488, 1.6514, 2.3711,
        1.9941, 2.3066, 1.4287, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
        2.1309, 1.2666, 1.1523, 1.9561, 1.8584, 1.9746, 1.5986, 1.9688, 2.1973,
        1.1523, 2.3281, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
        1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)

    REFERENCE_OLD_NO_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.7998, 1.9678, 1.3262, 1.7617, 1.8154, 1.9307,
        1.8574, 1.5449, 1.5293, 1.6074, 1.5557, 1.8545, 1.6582, 1.8838, 2.0195,
        1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
        1.1826, 1.8164, 2.1191, 1.6641, 2.0586, 1.6182, 1.7617, 1.7920, 1.4424,
        2.0723, 1.6865, 1.2969, 2.0840, 1.6729, 1.9639, 2.1602, 1.5996, 1.2773,
        2.2129, 1.8057, 1.7275, 1.6621, 1.6475, 1.4805, 1.7949, 1.5010, 0.8643,
        2.6680, 2.0918, 1.8545, 1.9795, 1.3271, 1.8350, 1.6338, 1.9766, 1.7881,
        1.6025, 1.7637, 1.7012, 1.7842, 1.5664, 0.8086, 1.7188, 1.6113, 1.8516,
        1.4434, 1.9590, 1.5801, 1.4209, 1.7168, 1.8408, 2.4141, 1.9658, 1.4922,
        2.1973, 1.9463, 1.8047, 1.2979, 1.6396, 1.6221, 1.5010, 1.9941, 1.7725,
        1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7188, 1.7686,
        1.4160, 1.7266, 2.1738, 1.9600, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
        2.1113, 1.7227, 1.5811, 1.7598, 2.2773, 1.8936, 1.4102, 1.5801, 1.7734,
        2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
        2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7793, 1.7744, 1.5312,
        1.2695, 1.9209, 2.0469, 1.6777, 2.5195, 1.8389, 1.7598, 1.5498, 1.6797,
        1.7324, 1.5928, 1.9258, 1.7734, 1.4463, 2.0391, 2.0508, 2.2129, 1.6787,
        2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7070, 1.1602, 1.8584,
        2.4570, 1.6016, 1.4834, 1.1777, 1.7959, 1.8955, 1.8906, 1.6738, 1.7510,
        1.4316, 1.8330, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
        1.8887, 1.5137, 1.4648, 1.6406, 1.7178, 2.2637, 1.5801, 2.1484, 2.0605,
        2.0098, 1.7539, 1.1768, 1.4375, 2.0723, 1.6162, 1.7832, 1.8291, 1.6064,
        1.5215, 1.9297, 2.3750, 2.1504, 1.7061, 1.1289, 1.4473, 1.5674, 1.6836,
        2.2930, 1.1221, 1.5547, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
        1.6348, 2.3379, 2.4414, 1.8857, 2.0020, 1.4834, 1.5488, 1.6514, 2.3711,
        1.9941, 2.3047, 1.4277, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
        2.1309, 1.2666, 1.1514, 1.9551, 1.8584, 1.9746, 1.5986, 1.9688, 2.1953,
        1.1514, 2.3262, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
        1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)

    @parameterized.expand([False, True])
    def test_cuda_old(self, use_half2: bool):

        group_size = 128

        # test the 256 kernel (in_features % 256 == 0 and out_features % 256 == 0)
        m = 1
        k = 256
        n = 256
        device = "cuda"

        linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size)

        linear = linear_class(
            bits=4,
            group_size=group_size,
            infeatures=k,
            outfeatures=n,
            bias=False,
        )

        torch.manual_seed(42)

        linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
        linear.scales = linear.scales + 0.002
        linear.use_cuda_fp16 = use_half2
        self.assertTrue(linear.autogptq_cuda_available)

        inp = torch.rand(1, m, k, dtype=torch.float16).to(device)

        linear = linear.eval()
        linear = linear.to(device)

        with torch.no_grad():
            res = linear(inp)[0][0]

        if use_half2:
            reference = self.REFERENCE_OLD_HALF.to(device)
        else:
            reference = self.REFERENCE_OLD_NO_HALF.to(device)

        self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))