diff --git a/.github/workflows/build_wheels_rocm.yml b/.github/workflows/build_wheels_rocm.yml new file mode 100644 index 0000000..a4331ca --- /dev/null +++ b/.github/workflows/build_wheels_rocm.yml @@ -0,0 +1,82 @@ +name: Build AutoGPTQ Wheels with ROCm + +on: workflow_dispatch + +jobs: + build_wheels: + if: ${{ github.repository_owner == 'PanQiWei' }} + + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.8", "3.9", "3.10"] # what's the point? + rocm: ["5.4.2", "5.5", "5.6"] + + name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }} + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash + + steps: + - uses: actions/checkout@v3 + with: + ref: 'main' + + - uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python }} + + - name: Free disk space + run: | + df -h + echo "Removing large packages" + sudo apt-get remove -y '^dotnet-.*' + sudo apt-get remove -y 'php.*' + sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel + df -h + sudo apt-get autoremove -y >/dev/null 2>&1 + sudo apt-get clean + sudo apt-get autoremove -y >/dev/null 2>&1 + sudo apt-get autoclean -y >/dev/null 2>&1 + df -h + echo "https://github.com/actions/virtual-environments/issues/709" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h + echo "remove big /usr/local" + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf /usr/local/lib/android >/dev/null 2>&1 + df -h + sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1 + sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1 + sudo rm -rf /usr/share/swift > /dev/null 2>&1 + df -h + + - name: Set up environment + run: | + if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then + export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb + elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then + export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb + else + export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb + fi + + curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/jammy/$ROCM_DL_FILE + sudo dpkg -i $ROCM_DL_FILE + sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends rocthrust-dev + + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm }} + python -m pip install --upgrade build setuptools wheel ninja + - name: Build wheels + run: | + ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel + - uses: actions/upload-artifact@v3 + with: + name: 'wheels' + path: ./dist/*.whl diff --git a/README.md b/README.md index d60361b..a59195c 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ ## News or Update +- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions. - 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`. - 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized model to train adapters, support LoRA, AdaLoRA, AdaptionPrompt, etc. - 2023-05-30 - (Update) - Support download/upload quantized model from/to 🤗 Hub. 
@@ -31,7 +32,7 @@ ### Inference Speed > The result is generated using [this script](examples/benchmark/generation_speed.py), batch size of input is 1, decode strategy is beam search and enforce the model to generate 512 tokens, speed metric is tokens/s (the larger, the better). -> +> > The quantized model is loaded using the setup that can gain the fastest inference speed. | model | GPU | num_beams | fp16 | gptq-int4 | @@ -53,14 +54,17 @@ For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200 ### Quick Installation You can install the latest stable release of AutoGPTQ from pip: + ```shell pip install auto-gptq ``` + Start from v0.2.0, you can download pre-build wheel that satisfied your environment setup from each version's release assets and install it to skip building stage for the fastest installation speed. For example: ```shell # firstly, cd the directory where the wheel saved, then execute command below pip install auto_gptq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl # install v0.2.0 auto_gptq pre-build wheel for linux in an environment whose python=3.10 and cuda=11.8 ``` + #### disable cuda extensions By default, cuda extensions will be installed when `torch` and `cuda` is already installed in your machine, if you don't want to use them, using: ```shell @@ -95,6 +99,12 @@ Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch Use `.[triton]` if you want to integrate with triton and it's available on your operating system. +To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. The compilation can be speeded up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example: + +``` +ROCM_VERSION=5.6 pip install . +``` + ## Quick Tour @@ -102,7 +112,7 @@ Use `.[triton]` if you want to integrate with triton and it's available on your ### Quantization and Inference > warning: this is just a showcase of the usage of basic apis in AutoGPTQ, which uses only one sample to quantize a much small model, quality of quantized model using such little samples may not good. -Below is an example for the simplest use of `auto_gptq` to quantize a model and inference after quantization: +Below is an example for the simplest use of `auto_gptq` to quantize a model and inference after quantization: ```python from transformers import AutoTokenizer, TextGenerationPipeline from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig @@ -125,7 +135,7 @@ examples = [ quantize_config = BaseQuantizeConfig( bits=4, # quantize model to 4-bit group_size=128, # it is recommended to set the value to 128 - desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad + desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad ) # load un-quantized model, by default, the model will always be loaded into CPU memory @@ -140,7 +150,7 @@ model.save_quantized(quantized_model_dir) # save quantized model using safetensors model.save_quantized(quantized_model_dir, use_safetensors=True) -# push quantized model to Hugging Face Hub. +# push quantized model to Hugging Face Hub. # to use use_auth_token=True, Login first via huggingface-cli login. 
# or pass explcit token with: use_auth_token="hf_xxxxxxx" # (uncomment the following three lines to enable this feature) @@ -188,8 +198,8 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM): "model.decoder.project_in", "model.decoder.final_layer_norm" ] # chained attribute names of linear layers in transformer layer module - # normally, there are four sub lists, for each one the modules in it can be seen as one operation, - # and the order should be the order when they are truly executed, in this case (and usually in most cases), + # normally, there are four sub lists, for each one the modules in it can be seen as one operation, + # and the order should be the order when they are truly executed, in this case (and usually in most cases), # they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output inside_layer_modules = [ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], @@ -260,14 +270,14 @@ task = SequenceClassificationTask( "num_samples": 1000, # how many samples will be sampled to evaluation "sample_max_len": 1024, # max tokens for each sample "block_max_len": 2048, # max tokens for each data block - # function to load dataset, one must only accept data_name_or_path as input + # function to load dataset, one must only accept data_name_or_path as input # and return datasets.Dataset - "load_fn": partial(datasets.load_dataset, name="english"), - # function to preprocess dataset, which is used for datasets.Dataset.map, + "load_fn": partial(datasets.load_dataset, name="english"), + # function to preprocess dataset, which is used for datasets.Dataset.map, # must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name] - "preprocess_fn": ds_refactor_fn, + "preprocess_fn": ds_refactor_fn, # truncate label when sample's length exceed sample_max_len - "truncate_prompt": False + "truncate_prompt": False } ) @@ -296,7 +306,7 @@ print( ## Supported Models > you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`. -> +> > for example, model_type of `WizardLM`, `vicuna` and `gpt4all` are all `llama`, hence they are all supported by `auto_gptq`. | model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt | @@ -315,9 +325,17 @@ print( ## Supported Evaluation Tasks Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon! +## Running tests + +Tests can be run with: + +``` +pytest tests/ -s +``` + ## Acknowledgement - Specially thanks **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing **GPTQ** algorithm and open source the [code](https://github.com/IST-DASLab/gptq). - Specially thanks **qwopqwop200**, for code in this project that relevant to quantization are mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda). 
-[![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date) \ No newline at end of file +[![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date) diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py index eee3f60..64cbcab 100644 --- a/auto_gptq/modeling/_base.py +++ b/auto_gptq/modeling/_base.py @@ -731,7 +731,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): raise TypeError(f"{config.model_type} isn't supported yet.") if quantize_config is None: - quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **kwargs) + quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs) if model_basename is None: if quantize_config.model_file_base_name: diff --git a/auto_gptq/nn_modules/qlinear/__init__.py b/auto_gptq/nn_modules/qlinear/__init__.py index cce73ea..bbd4c0c 100644 --- a/auto_gptq/nn_modules/qlinear/__init__.py +++ b/auto_gptq/nn_modules/qlinear/__init__.py @@ -8,7 +8,6 @@ class GeneralQuantLinear(nn.Linear): out_features=quant_linear_module.outfeatures, bias=True ) - self.infeatures = quant_linear_module.infeatures self.outfeatures = quant_linear_module.outfeatures self.bits = quant_linear_module.bits @@ -18,15 +17,15 @@ class GeneralQuantLinear(nn.Linear): self.weight.requires_grad = False self.weight.data = quant_linear_module.qweight - self.qweight = self.weight + self.register_buffer('qweight', quant_linear_module.qweight) self.bias.data = quant_linear_module.bias self.qweight.requires_grad = False self.bias.requires_grad = False - self.qzeros = quant_linear_module.qzeros - self.scales = quant_linear_module.scales - self.g_idx = quant_linear_module.g_idx + self.register_buffer('qzeros', quant_linear_module.qzeros) + self.register_buffer('scales', quant_linear_module.scales) + self.register_buffer('g_idx', quant_linear_module.g_idx) if hasattr(quant_linear_module, "wf"): self.wf = quant_linear_module.wf diff --git a/auto_gptq/utils/import_utils.py b/auto_gptq/utils/import_utils.py index 4230654..01b8b44 100644 --- a/auto_gptq/utils/import_utils.py +++ b/auto_gptq/utils/import_utils.py @@ -1,4 +1,6 @@ from packaging.version import parse as parse_version +from logging import getLogger +import torch try: import triton @@ -14,9 +16,13 @@ try: except: AUTOGPTQ_CUDA_AVAILABLE = False +logger = getLogger(__name__) def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int): if use_triton: + if torch.version.hip: + logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. 
Please use use_triton=False.") + from ..nn_modules.qlinear.qlinear_triton import QuantLinear else: if not desc_act or group_size == -1: diff --git a/autogptq_cuda/autogptq_cuda_kernel_256.cu b/autogptq_cuda/autogptq_cuda_kernel_256.cu index 3fe8a7a..e6e94f4 100644 --- a/autogptq_cuda/autogptq_cuda_kernel_256.cu +++ b/autogptq_cuda/autogptq_cuda_kernel_256.cu @@ -30,8 +30,9 @@ // } // #endif -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION) // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh + __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; @@ -76,7 +77,7 @@ __global__ void VecQuant2MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -91,7 +92,7 @@ __global__ void VecQuant3MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -121,7 +122,7 @@ __global__ void VecQuant8MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -135,7 +136,7 @@ __global__ void VecQuant2MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -150,7 +151,7 @@ __global__ void VecQuant3MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -165,7 +166,7 @@ __global__ void VecQuant4MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -180,7 +181,7 @@ __global__ void VecQuant8MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -208,7 +209,7 @@ __global__ void VecQuant3MatMulKernelFaster_old( const float* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -222,7 +223,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( const float* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -269,7 +270,7 @@ void vecquant2matmul_cuda( vec.type(), "vecquant2matmul_cuda", ([&] { VecQuant2MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -292,39 +293,39 @@ __global__ void VecQuant2MatMulKernel( ) { int h = BLOCKHEIGHT2 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 16; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 16); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); int k_bit = 
(k % 16) * 2; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); - + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); - + weight[k] = scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -356,7 +357,7 @@ void vecquant3matmul_cuda( vec.type(), "vecquant3matmul_cuda", ([&] { VecQuant3MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -379,15 +380,15 @@ __global__ void VecQuant3MatMulKernel( ) { int h = BLOCKHEIGHT3 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = (h / 3) * 32; int k; unsigned int g; scalar_t w_tmp; - - int z_w = (w / 32) * 3; + + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; unsigned int z_tmp; @@ -411,14 +412,14 @@ __global__ void VecQuant3MatMulKernel( z_w += 1; } } - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 32) * 3; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; int k_mod = k % 32; int k_bit; - + if (k_mod != 10){ if (k_mod != 21){ k_bit = k_mod; @@ -439,7 +440,7 @@ __global__ void VecQuant3MatMulKernel( k_w += 1; } } - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; scalar_t zero; @@ -452,7 +453,7 @@ __global__ void VecQuant3MatMulKernel( } else { zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } - + if (k_mod == 10) { w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); } else if (k_mod == 21){ @@ -464,12 +465,12 @@ __global__ void VecQuant3MatMulKernel( } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -501,7 +502,7 @@ void vecquant4matmul_cuda( vec.type(), "vecquant4matmul_cuda", ([&] { VecQuant4MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -524,40 +525,40 @@ __global__ void VecQuant4MatMulKernel( ) { int h = BLOCKHEIGHT4 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 8; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 8; + + int z_w = w / 8; int z_mod = (w % 8) * 4; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 8); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); int k_bit = (k % 8) * 4; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); - + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); - + weight[k] = 
scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -589,7 +590,7 @@ void vecquant8matmul_cuda( vec.type(), "vecquant8matmul_cuda", ([&] { VecQuant8MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -612,39 +613,39 @@ __global__ void VecQuant8MatMulKernel( ) { int h = BLOCKHEIGHT8 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 4; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 4; + + int z_w = w / 4; int z_mod = (w % 4) * 8; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 4); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); int k_bit = (k % 4) * 8; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); - + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); - + weight[k] = scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -711,19 +712,19 @@ __global__ void VecQuant2MatMulKernel_old( int i = width * h + w; int g_h = h * 16; int k = 0; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; unsigned int tmp; while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); - + res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2]; @@ -740,7 +741,7 @@ __global__ void VecQuant2MatMulKernel_old( res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13]; res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14]; res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15]; - + i += width; k += 16; } @@ -806,11 +807,11 @@ __global__ void VecQuant3MatMulKernel_old( int i = width * h + w; int g_h = (h / 3) * 32; int k = 0; - - int z_w = (w / 32) * 3; + + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; - + if (z_mod != 10){ if (z_mod != 21){ z_bit = z_mod; @@ -831,7 +832,7 @@ __global__ void VecQuant3MatMulKernel_old( z_w += 1; } } - + unsigned int tmp1; unsigned int tmp2; unsigned int tmp; @@ -839,7 +840,7 @@ __global__ void VecQuant3MatMulKernel_old( while (k < BLOCKWIDTH) { tmp1 = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; scalar_t zero; @@ -852,7 +853,7 @@ __global__ void VecQuant3MatMulKernel_old( } else { zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } - + res += (scale * scalar_t((tmp1 >> 0) & 
0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -863,14 +864,14 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; tmp2 = as_unsigned(mat[i]); tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); tmp2 >>= 1; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; k += 11; - + res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -881,14 +882,14 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; tmp1 = as_unsigned(mat[i]); tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); tmp1 >>= 2; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; k += 11; - + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -899,7 +900,7 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; k += 10; } @@ -966,18 +967,18 @@ __global__ void VecQuant4MatMulKernel_old( int g_h = h * 8; int k = 0; - int z_w = w / 8; + int z_w = w / 8; int z_mod = (w % 8) * 4; unsigned int tmp; while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); - + res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2]; @@ -986,7 +987,7 @@ __global__ void VecQuant4MatMulKernel_old( res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5]; res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6]; res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7]; - + i += width; k += 8; } @@ -1052,24 +1053,24 @@ __global__ void VecQuant8MatMulKernel_old( int i = width * h + w; int g_h = h * 4; int k = 0; - - int z_w = w / 4; + + int z_w = w / 4; int z_mod = (w % 4) * 8; unsigned int tmp; - while (k < BLOCKWIDTH) { + while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); - + res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3]; - + i += width; k += 4; } @@ -1091,7 +1092,7 @@ void 
vecquant2matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1143,8 +1144,8 @@ __global__ void VecQuant2MatMulKernelFaster_old( int i = width * h + w; int g_h = h * 16; int k = 0; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; float res = 0; @@ -1159,8 +1160,8 @@ __global__ void VecQuant2MatMulKernelFaster_old( float scale_f = scales[g * width + w]; half2 scale = __float2half2_rn(scale_f); half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1))); - - res2 = {}; + + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2); @@ -1172,7 +1173,7 @@ __global__ void VecQuant2MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2); i += width; k += 8; - res += __half2float(res2.x) + __half2float(res2.y); + res += __low2float(res2) + __high2float(res2); } atomicAdd(&mul[b * width + w], res); @@ -1191,7 +1192,7 @@ void vecquant3matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1243,11 +1244,11 @@ __global__ void VecQuant3MatMulKernelFaster_old( int i = width * h + w; int g_h = (h / 3) * 32; int k = 0; - + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; - + if (z_mod != 10){ if (z_mod != 21){ z_bit = z_mod; @@ -1293,8 +1294,8 @@ __global__ void VecQuant3MatMulKernelFaster_old( } else { zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1))); } - - res2 = {}; + + std::memset(&res2, 0, sizeof(half2)); tmp1 = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); @@ -1324,7 +1325,7 @@ __global__ void VecQuant3MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); i += width; k += 5; - res += __half2float(res2.x) + __half2float(res2.y); + res += __low2float(res2) + __high2float(res2); } atomicAdd(&mul[b * width + w], res); @@ -1343,7 +1344,7 @@ void vecquant4matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1396,7 +1397,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( int g_h = h * 8; int k = 0; - int z_w = w / 8; + int z_w = w / 8; int z_mod = (w % 8) * 4; float res = 0; @@ -1409,10 +1410,15 @@ __global__ void VecQuant4MatMulKernelFaster_old( while (k < blockwidth2) { int g = (g_h + (k * 2)) / groupsize; float scale_f = scales[g * width + w]; + half2 scale = __float2half2_rn(scale_f); half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1))); - - res2 = {}; + + //std::memset(&res2, 0, sizeof(half2)); + + //res2 = __float2half2_rn((float)0.); + + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2); 
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2); @@ -1420,7 +1426,9 @@ __global__ void VecQuant4MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2); i += width; k += 4; - res += __half2float(res2.x) + __half2float(res2.y); + + res += __low2float(res2) + __high2float(res2); + } atomicAdd(&mul[b * width + w], res); diff --git a/autogptq_cuda/autogptq_cuda_kernel_64.cu b/autogptq_cuda/autogptq_cuda_kernel_64.cu index 6cc6036..e0a1b87 100644 --- a/autogptq_cuda/autogptq_cuda_kernel_64.cu +++ b/autogptq_cuda/autogptq_cuda_kernel_64.cu @@ -30,7 +30,8 @@ // } // #endif -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(ROCM_VERSION) // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); @@ -1161,7 +1162,7 @@ __global__ void VecQuant2MatMulKernelFaster_old( half2 scale = __float2half2_rn(scale_f); half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1))); - res2 = {}; + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2); @@ -1173,7 +1174,7 @@ __global__ void VecQuant2MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2); i += width; k += 8; - res += __half2float(res2.x) + __half2float(res2.y); + res += __low2float(res2) + __high2float(res2); } atomicAdd(&mul[b * width + w], res); @@ -1295,7 +1296,7 @@ __global__ void VecQuant3MatMulKernelFaster_old( zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1))); } - res2 = {}; + std::memset(&res2, 0, sizeof(half2)); tmp1 = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); @@ -1325,7 +1326,7 @@ __global__ void VecQuant3MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); i += width; k += 5; - res += __half2float(res2.x) + __half2float(res2.y); + res += __low2float(res2) + __high2float(res2); } atomicAdd(&mul[b * width + w], res); @@ -1413,7 +1414,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( half2 scale = __float2half2_rn(scale_f); half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1))); - res2 = {}; + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2); @@ -1421,7 +1422,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2); i += width; k += 4; - res += __half2float(res2.x) + __half2float(res2.y); + res += __low2float(res2) + __high2float(res2); } atomicAdd(&mul[b * width + w], res); diff --git a/setup.py b/setup.py index df01edc..cd4c5bf 100644 --- a/setup.py +++ 
b/setup.py @@ -18,9 +18,15 @@ if BUILD_CUDA_EXT: except: print("torch is not installed, please install torch first!") sys.exit(-1) - CUDA_VERSION = "".join(torch.version.cuda.split(".")) -else: - CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", "").split(".")) + + CUDA_VERSION = False + ROCM_VERSION = os.environ.get('ROCM_VERSION', False) + if ROCM_VERSION and not torch.version.hip: + raise ValueError(f"Trying to compile AutoGPTQ for RoCm, but PyTorch {torch.__version__} is installed with no RoCm support.") + + if not ROCM_VERSION: + default_cuda_version = "".join(torch.version.cuda.split(".")) + CUDA_VERSION = os.environ.get("CUDA_VERSION", default_cuda_version) common_setup_kwargs = { "version": "0.3.2", @@ -46,8 +52,13 @@ common_setup_kwargs = { "python_requires": f">={python_min_version_str}" } -if CUDA_VERSION: - common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}" +if BUILD_CUDA_EXT: + if ROCM_VERSION: + common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}" + else: + assert CUDA_VERSION + common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}" + requirements = [ "accelerate>=0.19.0", @@ -72,11 +83,15 @@ include_dirs = ["autogptq_cuda"] additional_setup_kwargs = dict() if BUILD_CUDA_EXT: from torch.utils import cpp_extension - from distutils.sysconfig import get_python_lib - conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") - if os.path.isdir(conda_cuda_include_dir): - include_dirs.append(conda_cuda_include_dir) - print(f"appending conda cuda include dir {conda_cuda_include_dir}") + + if not ROCM_VERSION: + from distutils.sysconfig import get_python_lib + conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") + + print("conda_cuda_include_dir", conda_cuda_include_dir) + if os.path.isdir(conda_cuda_include_dir): + include_dirs.append(conda_cuda_include_dir) + print(f"appending conda cuda include dir {conda_cuda_include_dir}") extensions = [ cpp_extension.CUDAExtension( "autogptq_cuda_64", diff --git a/tests/test_q4.py b/tests/test_q4.py new file mode 100644 index 0000000..5d9b165 --- /dev/null +++ b/tests/test_q4.py @@ -0,0 +1,113 @@ +import unittest +from parameterized import parameterized +import torch +from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + +def get_diff(a, ref): + eps = 1e-6 + return f"Maxdiff: {(a - ref).abs().max()}, Mean relative diff: {((a - ref).abs() / (ref.abs() + eps)).mean()}" + +class TestsQ4CUDA(unittest.TestCase): + + REFERENCE_OLD_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.8008, 1.9688, 1.3262, 1.7627, 1.8164, 1.9307, + 1.8574, 1.5449, 1.5293, 1.6074, 1.5566, 1.8545, 1.6582, 1.8838, 2.0215, + 1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756, + 1.1826, 1.8174, 2.1191, 1.6641, 2.0586, 1.6182, 1.7627, 1.7920, 1.4424, + 2.0723, 1.6865, 1.2979, 2.0840, 1.6729, 1.9648, 2.1602, 1.6006, 1.2773, + 2.2129, 1.8057, 1.7285, 1.6621, 1.6475, 1.4805, 1.7959, 1.5010, 0.8643, + 2.6680, 2.0918, 1.8555, 1.9795, 1.3271, 1.8359, 1.6338, 1.9766, 1.7881, + 1.6025, 1.7637, 1.7012, 1.7852, 1.5674, 0.8091, 1.7188, 1.6123, 1.8525, + 1.4434, 1.9590, 1.5801, 1.4209, 1.7178, 1.8408, 2.4141, 1.9658, 1.4922, + 2.1992, 1.9473, 1.8047, 1.2979, 1.6396, 1.6221, 1.5020, 1.9941, 1.7725, + 1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7197, 1.7686, + 1.4160, 1.7275, 2.1738, 1.9609, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002, + 2.1113, 1.7227, 1.5811, 1.7607, 2.2773, 1.8945, 1.4111, 1.5801, 1.7744, + 2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 
1.5840, + 2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7803, 1.7744, 1.5312, + 1.2695, 1.9209, 2.0469, 1.6777, 2.5215, 1.8389, 1.7598, 1.5498, 1.6807, + 1.7324, 1.5938, 1.9268, 1.7734, 1.4463, 2.0391, 2.0527, 2.2129, 1.6787, + 2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7080, 1.1611, 1.8584, + 2.4570, 1.6016, 1.4834, 1.1777, 1.7969, 1.8955, 1.8906, 1.6738, 1.7510, + 1.4316, 1.8340, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629, + 1.8887, 1.5137, 1.4648, 1.6406, 1.7188, 2.2656, 1.5801, 2.1484, 2.0625, + 2.0098, 1.7549, 1.1768, 1.4385, 2.0723, 1.6172, 1.7832, 1.8301, 1.6064, + 1.5215, 1.9297, 2.3750, 2.1504, 1.7070, 1.1289, 1.4473, 1.5674, 1.6836, + 2.2930, 1.1221, 1.5557, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988, + 1.6348, 2.3379, 2.4414, 1.8857, 2.0039, 1.4844, 1.5488, 1.6514, 2.3711, + 1.9941, 2.3066, 1.4287, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502, + 2.1309, 1.2666, 1.1523, 1.9561, 1.8584, 1.9746, 1.5986, 1.9688, 2.1973, + 1.1523, 2.3281, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523, + 1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16) + + REFERENCE_OLD_NO_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.7998, 1.9678, 1.3262, 1.7617, 1.8154, 1.9307, + 1.8574, 1.5449, 1.5293, 1.6074, 1.5557, 1.8545, 1.6582, 1.8838, 2.0195, + 1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756, + 1.1826, 1.8164, 2.1191, 1.6641, 2.0586, 1.6182, 1.7617, 1.7920, 1.4424, + 2.0723, 1.6865, 1.2969, 2.0840, 1.6729, 1.9639, 2.1602, 1.5996, 1.2773, + 2.2129, 1.8057, 1.7275, 1.6621, 1.6475, 1.4805, 1.7949, 1.5010, 0.8643, + 2.6680, 2.0918, 1.8545, 1.9795, 1.3271, 1.8350, 1.6338, 1.9766, 1.7881, + 1.6025, 1.7637, 1.7012, 1.7842, 1.5664, 0.8086, 1.7188, 1.6113, 1.8516, + 1.4434, 1.9590, 1.5801, 1.4209, 1.7168, 1.8408, 2.4141, 1.9658, 1.4922, + 2.1973, 1.9463, 1.8047, 1.2979, 1.6396, 1.6221, 1.5010, 1.9941, 1.7725, + 1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7188, 1.7686, + 1.4160, 1.7266, 2.1738, 1.9600, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002, + 2.1113, 1.7227, 1.5811, 1.7598, 2.2773, 1.8936, 1.4102, 1.5801, 1.7734, + 2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840, + 2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7793, 1.7744, 1.5312, + 1.2695, 1.9209, 2.0469, 1.6777, 2.5195, 1.8389, 1.7598, 1.5498, 1.6797, + 1.7324, 1.5928, 1.9258, 1.7734, 1.4463, 2.0391, 2.0508, 2.2129, 1.6787, + 2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7070, 1.1602, 1.8584, + 2.4570, 1.6016, 1.4834, 1.1777, 1.7959, 1.8955, 1.8906, 1.6738, 1.7510, + 1.4316, 1.8330, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629, + 1.8887, 1.5137, 1.4648, 1.6406, 1.7178, 2.2637, 1.5801, 2.1484, 2.0605, + 2.0098, 1.7539, 1.1768, 1.4375, 2.0723, 1.6162, 1.7832, 1.8291, 1.6064, + 1.5215, 1.9297, 2.3750, 2.1504, 1.7061, 1.1289, 1.4473, 1.5674, 1.6836, + 2.2930, 1.1221, 1.5547, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988, + 1.6348, 2.3379, 2.4414, 1.8857, 2.0020, 1.4834, 1.5488, 1.6514, 2.3711, + 1.9941, 2.3047, 1.4277, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502, + 2.1309, 1.2666, 1.1514, 1.9551, 1.8584, 1.9746, 1.5986, 1.9688, 2.1953, + 1.1514, 2.3262, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523, + 1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16) + + @parameterized.expand([False, True]) + def test_cuda_old(self, use_half2: bool): + + group_size = 128 + + # test the 256 kernel (in_features % 256 == 0 and out_features % 256 == 0) + m = 1 + k = 256 + n = 256 + device = "cuda" + + linear_class = 
dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size) + + linear = linear_class( + bits=4, + group_size=group_size, + infeatures=k, + outfeatures=n, + bias=False, + ) + + torch.manual_seed(42) + + linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) + linear.scales = linear.scales + 0.002 + linear.use_cuda_fp16 = use_half2 + self.assertTrue(linear.autogptq_cuda_available) + + inp = torch.rand(1, m, k, dtype=torch.float16).to(device) + + linear = linear.eval() + linear = linear.to(device) + + with torch.no_grad(): + res = linear(inp)[0][0] + + if use_half2: + reference = self.REFERENCE_OLD_HALF.to(device) + else: + reference = self.REFERENCE_OLD_NO_HALF.to(device) + + self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))
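
---

A note on the `GeneralQuantLinear` change in `auto_gptq/nn_modules/qlinear/__init__.py`: switching from plain attribute assignment to `register_buffer` means `qweight`, `qzeros`, `scales` and `g_idx` are tracked by the module, so they follow `.to(device)` / `.half()` calls and appear in `state_dict()`. A minimal sketch of the difference, using hypothetical toy modules (not AutoGPTQ code):

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # plain tensor attribute: not tracked by the module
        self.scales = torch.ones(4)


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: moves with .to(...) and is saved in state_dict()
        self.register_buffer("scales", torch.ones(4))


a, b = PlainAttr(), WithBuffer()
print("scales" in a.state_dict())  # False
print("scales" in b.state_dict())  # True
# b.to("cuda") would also move b.scales; a.scales would stay on CPU
```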
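Both the `setup.py` and `import_utils.py` changes rely on `torch.version.hip` to tell a ROCm build of PyTorch from a CUDA one. A small sketch of that check, assuming only that `torch` is importable in the environment:

```python
import torch

if torch.version.hip is not None:
    # ROCm build of PyTorch; torch.version.cuda is None in this case
    print(f"ROCm PyTorch detected, HIP version {torch.version.hip}")
elif torch.version.cuda is not None:
    print(f"CUDA PyTorch detected, CUDA version {torch.version.cuda}")
else:
    print("CPU-only PyTorch build")
```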
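The new `tests/test_q4.py` compares kernel output against hard-coded fp16 references with `torch.allclose` at its default tolerances. If that comparison ever needs loosening (for example across GPU architectures), explicit tolerances can be passed instead; the helper below is a hypothetical variant, not part of the PR:

```python
import torch


def assert_close_fp16(res: torch.Tensor, ref: torch.Tensor,
                      atol: float = 1e-3, rtol: float = 1e-2) -> None:
    # fp16 carries roughly 3 decimal digits of precision, so compare in fp32
    # with explicit absolute and relative tolerances
    if not torch.allclose(res.float(), ref.float(), atol=atol, rtol=rtol):
        max_diff = (res.float() - ref.float()).abs().max().item()
        raise AssertionError(
            f"Max abs diff {max_diff} exceeds atol={atol}, rtol={rtol}"
        )
```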