Merge branch 'main' into MPT

# Conflicts:
#	auto_gptq/modeling/__init__.py
#	auto_gptq/modeling/_const.py
#	auto_gptq/modeling/auto.py
Author: LaaZa
Date: 2023-10-04 20:21:49 +03:00
Commit: 4b7389ddb7
71 changed files with 8719 additions and 361 deletions


@ -5,12 +5,12 @@ on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
name: Build wheels for ${{ matrix.os }}
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA ${{ matrix.cuda }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
pyver: ["3.8", "3.9", "3.10"]
os: [ubuntu-20.04, windows-latest]
pyver: ["3.8", "3.9", "3.10", "3.11"]
cuda: ["11.7", "11.8"]
defaults:
run:
@ -20,20 +20,18 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
ref: 'main'
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.pyver }}
- name: Setup Mamba
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.pyver }}
mamba-version: "*"
use-mamba: true
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
@ -41,14 +39,19 @@ jobs:
- name: Install Dependencies
run: |
mamba install -y "pytorch[version=2.0.0,build=py*_cuda${env:CUDA_VERSION}*]" "pytorch-cuda=${env:CUDA_VERSION}" 'sentencepiece' 'cuda' 'ninja' -c 'pytorch' -c "nvidia/label/cuda-${env:CUDA_VERSION}.0" -c 'nvidia' -c 'conda-forge' -c 'defaults'
python -m pip install --upgrade build setuptools wheel
conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
python -m pip install --upgrade build setuptools wheel ninja
- name: Build Wheel
run: |
$env:CUDA_PATH = $env:CONDA_PREFIX
$env:CUDA_HOME = $env:CONDA_PREFIX
if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}
# TODO: remove this
if (!$IsLinux) {$env:INCLUDE_EXLLAMA_KERNELS = 0}
$env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }
python setup.py sdist bdist_wheel
@ -56,11 +59,11 @@ jobs:
- uses: actions/upload-artifact@v3
if: runner.os == 'Linux'
with:
name: 'linux-wheels'
name: 'linux-cuda-wheels'
path: ./dist/*.whl
- uses: actions/upload-artifact@v3
if: runner.os == 'Windows'
with:
name: 'windows-wheels'
name: 'windows-cuda-wheels'
path: ./dist/*.whl

.github/workflows/build_wheels_pypi.yml (new file)

@ -0,0 +1,74 @@
name: Build AutoGPTQ Wheels for PyPI with CUDA
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA 11.7
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04, windows-latest]
pyver: ["3.8", "3.9", "3.10", "3.11"]
defaults:
run:
shell: pwsh
env:
CUDA_VERSION: "11.7"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.pyver }}
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.pyver }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Install Dependencies
run: |
conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
python -m pip install --upgrade build setuptools wheel ninja
- name: Build Wheel
run: |
$env:CUDA_PATH = $env:CONDA_PREFIX
$env:CUDA_HOME = $env:CONDA_PREFIX
if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}
$env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }
$env:PYPI_RELEASE = "1"
echo "CUDA_PATH:"
echo $env:CUDA_PATH
echo "PYPI_RELEASE:"
echo $env:PYPI_RELEASE
python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
if: runner.os == 'Linux'
with:
name: 'linux-cuda-wheels'
path: ./dist/*.whl
- uses: actions/upload-artifact@v3
if: runner.os == 'Windows'
with:
name: 'windows-cuda-wheels'
path: ./dist/*.whl

.github/workflows/build_wheels_rocm.yml (new file)

@ -0,0 +1,103 @@
name: Build AutoGPTQ Wheels with ROCm
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
strategy:
matrix:
os: [ubuntu-20.04]
python: ["3.8", "3.9", "3.10", "3.11"]
rocm: ["5.4.2"] # , "5.5", "5.6"]
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }}
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v3
- name: Free disk space
run: |
df -h
echo "Removing large packages"
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel
df -h
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get clean
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get autoclean -y >/dev/null 2>&1
df -h
echo "https://github.com/actions/virtual-environments/issues/709"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
df -h
echo "remove big /usr/local"
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
df -h
sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
sudo rm -rf /usr/share/swift > /dev/null 2>&1
df -h
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python }}
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.python }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Set up environment
run: |
echo "Using python:"
python --version
which python
if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb
else
export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb
fi
curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
sudo dpkg -i $ROCM_DL_FILE
sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev
python -m pip install --upgrade build setuptools wheel ninja
python -m pip install torch --index-url https://download.pytorch.org/whl/rocm${{ matrix.rocm }}
- name: Build wheels
run: |
echo "Using python for build:"
python --version
which python
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
with:
name: 'linux-rocm-wheels'
path: ./dist/*.whl


@ -15,15 +15,22 @@
</p>
</h4>
*<center>📣 Long time no see! 👋 Architecture upgrades, performance optimizations and more new features will come in July and August, stay tuned! 🥂</center>*
## The path to v1.0.0
Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been huge in terms of my career plans. Not long ago, I officially bid farewell to the startup team that I had joined for two years after graduation. I'm very grateful to the team's leaders and colleagues for their trust and guidance, which enabled me to grow rapidly over those two years; at the same time, I'm really grateful to the team for letting me use the internal A100 GPU server cluster free of charge since the start of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, I can no longer use it going forward, so **new hardware sponsorship would mean a lot to me!**) In the past two years, I served as an AI engineer on this team, responsible for the architecture design and development of an LLM-based dialogue system. We successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), which is **an LLM-native AI agent platform, where users can use multiple AI agents to build a highly automated team, letting them interact with each other in a workflow and collaborate to complete complex projects efficiently.**
Getting back to the topic, I'm very excited to see that in the past few months research on optimizing the inference performance of LLMs has made tremendous progress. Now we can run LLM inference efficiently not only on high-end GPUs, but also on CPUs and even edge devices. This series of technological advances makes me eager to contribute more to the open source community. Therefore, I will first spend about four weeks gradually updating AutoGPTQ to the official v1.0.0 version. During this period, 2~3 minor versions will also be released so that users can try out performance optimizations and new features in a timely manner. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will be able to serve as an extendable and flexible quantization backend that supports all GPTQ-like methods and automatically quantizes LLMs written in PyTorch**. I detailed the development plan in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/348); feel free to drop by for discussion and give your suggestions!
## News or Update
- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so running and training GPTQ models is now more accessible to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team officially released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, and provided [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel, giving at least a 1.3x inference speedup for int4 quantized models.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to train adapters on GPTQ quantized models; supports LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support downloading/uploading quantized models from/to the 🤗 Hub.
- 2023-05-27 - (Update) - Support quantization and inference for `gpt_bigcode`, `codegen` and `RefinedWeb/RefinedWebModel` (falcon) model types.
- 2023-05-04 - (Update) - Support using a faster CUDA kernel when `not desc_act or group_size == -1`.
*For more history please turn to [here](docs/NEWS_OR_UPDATE.md)*
@ -52,36 +59,17 @@ For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200
## Installation
### Quick Installation
You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
Starting from v0.2.0, you can download a pre-built wheel that matches your environment from each release's assets and install it to skip the build stage for the fastest installation. For example:
```shell
# first, cd to the directory where the wheel is saved, then run the command below
pip install auto_gptq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl # install the v0.2.0 auto_gptq pre-built wheel for Linux with python=3.10 and cuda=11.8
```
#### disable cuda extensions
By default, CUDA extensions will be installed when `torch` and CUDA are already present on your machine; if you don't want to use them, install with:
```shell
BUILD_CUDA_EXT=0 pip install auto-gptq
```
And to make sure `autogptq_cuda` is never left in your virtual environment, run:
```shell
pip uninstall autogptq_cuda -y
```
You can install the latest stable release of AutoGPTQ from pip with pre-built wheels compatible with PyTorch 2.0.1:
#### to support triton speedup
To integrate with `triton`, use:
> warning: currently triton only supports Linux; 3-bit quantization is not supported when using triton
* For CUDA 11.7: `pip install auto-gptq`
* For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`
```shell
pip install auto-gptq[triton]
```
**Warning:** These wheels are not expected to work on PyTorch nightly. Please install AutoGPTQ from source when using PyTorch nightly.
AutoGPTQ can be installed with the Triton dependency via `pip install auto-gptq[triton]` in order to use the Triton backend (currently Linux-only; 3-bit quantization is not supported).
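For example, to install the pre-built wheel together with the Triton extra on a Linux machine with CUDA 11.8 (an illustrative combination of the options above; adjust the index URL to your setup):
```shell
# pre-built wheel for CUDA 11.8 plus the optional Triton backend (Linux only)
pip install auto-gptq[triton] --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
```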
### Install from source
<details>
<summary>click to see details</summary>
Clone the source code:
```shell
@ -89,13 +77,17 @@ git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
```
Then, install from source:
```shell
pip install .
pip install -v .
```
Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch extension building.
You can set `BUILD_CUDA_EXT=0` to disable pytorch extension building, but this is **strongly discouraged** as AutoGPTQ then falls back on a slow python implementation.
Use `.[triton]` if you want to integrate with triton and it's available on your operating system.
To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. Compilation can be sped up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
</details>
```
ROCM_VERSION=5.6 pip install -v .
```
For RoCm systems, the packages `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev` are required to build.
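For instance, on an MI200-series machine the build can be narrowed to a single GPU architecture (a sketch based only on the two variables described above):
```shell
# build only for gfx90a (MI200 series) against RoCm 5.6
PYTORCH_ROCM_ARCH=gfx90a ROCM_VERSION=5.6 pip install -v .
```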
## Quick Tour
@ -315,6 +307,14 @@ print(
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
## Running tests
Tests can be run with:
```
pytest tests/ -s
```
## Acknowledgement
- Special thanks to **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing the **GPTQ** algorithm and open-sourcing the [code](https://github.com/IST-DASLab/gptq).
- Special thanks to **qwopqwop200**; the quantization-related code in this project is mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).


@ -15,15 +15,21 @@
</p>
</h4>
*<center>📣 Long time no see! 👋 Architecture upgrades, performance optimizations and new features are coming in July and August, stay tuned! 🥂</center>*
## The path to v1.0.0
Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been of great significance to my career plans. Not long ago, I officially bid farewell to the startup team that I had joined for two years after graduation. I'm very grateful to the team's leaders and colleagues for their trust and guidance, which allowed me to grow rapidly over those two years; I'm also very grateful that the team let me use the internal A100 GPU server cluster free of charge since the founding of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, I can no longer use it going forward, so **new hardware sponsorship would mean a lot to me!**) Over the past two years, I served as an algorithm engineer on this team, responsible for the architecture design and development of an LLM-based dialogue system. We successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), **an LLM-native AI agent platform where users can use multiple AI agents to build a highly automated team, letting them collaborate with each other in a workflow to complete complex projects efficiently.**
Getting back to the topic, I'm very excited to see that over the past few months, research on optimizing the inference performance of LLMs has made tremendous progress. Now we can not only run LLM inference on high-end GPUs, but also easily run LLMs on CPUs and even edge devices. This series of technological advances makes me eager to contribute more to the open source community as well, so I will first spend about four weeks iterating AutoGPTQ to the official v1.0.0 release; during this period, 2~3 minor versions will also be released so that users can try out performance optimizations and new features in a timely manner. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will be able to serve as a flexible and extendable quantization backend that supports all GPTQ-like methods and automatically quantizes LLMs written in PyTorch**. I described the development plan in detail in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/348); feel free to join the discussion there and share your suggestions!
## News or Update
- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, making inference and training with GPTQ models easier! Read [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team released a 4-bit quantized Qwen-7B model based on `auto-gptq`, along with [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel, giving at least a 1.3x inference speedup for int4 quantized models.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq's CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to train adapters on GPTQ quantized models; supports LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support downloading quantized models from the 🤗 Hub and uploading quantized models to the 🤗 Hub.
- 2023-05-27 - (Update) - Support quantization and inference for the following model types: `gpt_bigcode`, `codegen` and `RefinedWeb/RefinedWebModel` (falcon).
- 2023-05-04 - (Update) - Support using a faster CUDA kernel when `not desc_act or group_size == -1`.
*For more history, please see [here](docs/NEWS_OR_UPDATE.md)*
@ -52,15 +58,14 @@
## Installation
### Quick Installation
You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
Starting from v0.2.0, you can download a pre-built wheel matching your system configuration from each release's assets and install it to skip the lengthy build process for the fastest installation. For example:
```shell
# first, cd to the directory where the wheel is saved, then run the command below
pip install auto_gptq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl # install v0.2.0 auto_gptq on Linux with python=3.10 and cuda=11.8
```
You can install the latest stable release of AutoGPTQ from pip as pre-built wheels compatible with PyTorch 2.0.1:
* For CUDA 11.7: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/`
* For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`
**Warning:** Pre-built wheels may not work with PyTorch nightly. If you want to use PyTorch nightly, please install AutoGPTQ from source.
#### disable cuda extensions
By default, CUDA extensions will be installed when `torch` and CUDA are already present on your machine; if you don't want these extensions, install with:
```shell
@ -95,6 +100,14 @@ pip install .
Use `.[triton]` if you want triton speedup and it is supported on your operating system.
To install from source for AMD GPUs supporting RoCm, please set the `ROCM_VERSION` environment variable. Compilation can be sped up by setting the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
```
ROCM_VERSION=5.6 pip install .
```
For RoCm systems, the following packages must also be installed in advance when building from source: `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev`.
</details>
## Quick Start


@ -1,4 +1,5 @@
__version__ = "0.3.2"
__version__ = "0.5.0.dev0"
from .modeling import BaseQuantizeConfig
from .modeling import AutoGPTQForCausalLM
from .utils.peft_utils import get_gptq_peft_model
from .utils.exllama_utils import exllama_set_max_input_length
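A minimal sketch of how the newly exported helper might be used with the exllama backend; the local model path is hypothetical, and `exllama_set_max_input_length` only matters for act-order models that exceed the 2048-token default buffer:
```python
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

# hypothetical path to a 4-bit GPTQ checkpoint
model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",
    device="cuda:0",
    disable_exllama=False,   # use the exllama kernel so the buffer resize applies
    disable_exllamav2=True,
)
model = exllama_set_max_input_length(model, 4096)  # grow the temp_state buffer
```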


@ -12,4 +12,5 @@ from .gpt_bigcode import *
from .codegen import *
from .baichuan import *
from .internlm import *
from .qwen import *
from .mpt import *


@ -13,6 +13,7 @@ import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module
from safetensors.torch import save_file as safe_save
from safetensors.torch import load_file as safe_load
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
@ -24,7 +25,9 @@ from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE
from ..utils.import_utils import (
dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE, EXLLAMAV2_KERNELS_AVAILABLE
)
logger = getLogger(__name__)
@ -35,6 +38,7 @@ class BaseQuantizeConfig(PushToHubMixin):
group_size: int = field(default=-1)
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
static_groups: bool = field(default=False)
sym: bool = field(default=True)
true_sequential: bool = field(default=True)
model_name_or_path: Optional[str] = field(default=None)
@ -87,8 +91,16 @@ class BaseQuantizeConfig(PushToHubMixin):
_commit_hash=commit_hash,
)
field_names = [field.name for field in fields(cls)]
with open(resolved_config_file, "r", encoding="utf-8") as f:
return cls(**json.load(f))
args_from_json = json.load(f)
filtered_args = {}
for key, val in args_from_json.items():
if key in field_names:
filtered_args[key] = val
else:
logger.warning(f"ignoring unknown parameter in {quantize_config_filename}: {key}.")
return cls(**filtered_args)
def to_dict(self):
return {
@ -96,6 +108,7 @@ class BaseQuantizeConfig(PushToHubMixin):
"group_size": self.group_size,
"damp_percent": self.damp_percent,
"desc_act": self.desc_act,
"static_groups": self.static_groups,
"sym": self.sym,
"true_sequential": self.true_sequential,
"model_name_or_path": self.model_name_or_path,
@ -361,7 +374,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
scale, zero, g_idx = gptq[name].fasterquant(
percdamp=self.quantize_config.damp_percent,
group_size=self.quantize_config.group_size,
actorder=self.quantize_config.desc_act
actorder=self.quantize_config.desc_act,
static_groups=self.quantize_config.static_groups
)
quantizers[f'{self.layers_block_name}.{i}.{name}'] = (
gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device),
@ -427,7 +441,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
return torch.device(device)
def to(self, device: Union[str, torch.device]):
return self.model.to(device)
self.model.to(device)
return self
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
@ -682,7 +697,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
torch_dtype: torch.dtype = torch.float16,
use_qigen: bool = False,
torch_dtype: Optional[torch.dtype] = None,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True,
use_cuda_fp16: bool = True,
@ -692,6 +708,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
**kwargs
):
"""load quantized model from local disk"""
@ -719,10 +737,52 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
if use_qigen and not QIGEN_AVAILABLE:
logger.warning("Qigen is not installed, reset use_qigen to False.")
use_qigen = False
if use_triton and not TRITON_AVAILABLE:
logger.warning("triton is not installed, reset use_triton to False")
logger.warning("Triton is not installed, reset use_triton to False.")
use_triton = False
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
logger.warning(
"Exllama kernel is not installed, reset disable_exllama to True. "
"This may because you installed auto_gptq using a pre-build wheel "
"on Windows, in which exllama_kernels are not compiled. To use "
"exllama_kernels to further speedup inference, you can re-install "
"auto_gptq from source."
)
disable_exllama = True
if not disable_exllamav2 and not EXLLAMAV2_KERNELS_AVAILABLE:
logger.warning(
"Exllamav2 kernel is not installed, reset disable_exllamav2 to True. "
"This may because you installed auto_gptq using a pre-build wheel "
"on Windows, in which exllama_kernels are not compiled. To use "
"exllama_kernels to further speedup inference, you can re-install "
"auto_gptq from source."
)
disable_exllamav2 = True
if not AUTOGPTQ_CUDA_AVAILABLE:
logger.warning(
"CUDA kernels for auto_gptq are not installed, this will result in "
"very slow inference speed. This may because:\n"
"1. You disabled CUDA extensions compilation by setting BUILD_CUDA_EXT=0 when install auto_gptq from source.\n"
"2. You are using pytorch without CUDA support.\n"
"3. CUDA and nvcc are not installed in your device."
)
if use_qigen and QIGEN_AVAILABLE:
logger.warning("QIgen is active. Ignores all settings related to cuda.")
inject_fused_attention = False
inject_fused_mlp = False
use_triton = False
disable_exllama = True
disable_exllamav2 = True
if not disable_exllamav2 and not disable_exllama:
logger.warning(
"You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False"
)
disable_exllama = True
# == step1: prepare configs and file names == #
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
@ -731,7 +791,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
raise TypeError(f"{config.model_type} isn't supported yet.")
if quantize_config is None:
quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **kwargs)
quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)
if model_basename is None:
if quantize_config.model_file_base_name:
@ -769,13 +829,25 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
model_save_name = resolved_archive_file
if not use_triton and trainable:
if (not disable_exllama or not disable_exllamav2) and trainable:
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
disable_exllama = True
disable_exllamav2 = True
elif not use_triton and trainable:
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
pass
if torch_dtype is None:
if not use_qigen:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
if not use_qigen:
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
@ -806,6 +878,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
quantize_config.bits,
quantize_config.group_size,
use_triton=use_triton,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable
@ -842,7 +916,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
)
if low_cpu_mem_usage:
make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits)
accelerate.utils.modeling.load_checkpoint_in_model(
model,
@ -852,7 +926,47 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
offload_buffers=True
)
model = simple_dispatch_model(model, device_map)
else:
if quantize_config.desc_act:
NotImplementedError('desc_act=True is not yet supported.')
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype
)
layers = find_layers(model)
ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
for name in list(layers.keys()):
if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
logger.info(f"{name} not been quantized, will be ignored when make_quant.")
del layers[name]
if model_save_name.endswith('.safetensors'):
checkpoint = safe_load(model_save_name)
else:
checkpoint = torch.load(model_save_name)
make_quant(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
use_triton=use_triton,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable,
use_qigen=True
)
preprocess_checkpoint_qigen(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
checkpoint
)
model.load_state_dict(checkpoint)
# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
@ -877,7 +991,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
group_size=quantize_config.group_size,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable
trainable=trainable,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2
)
if inject_fused_mlp:
if cls.fused_mlp_module_type is None:
@ -889,6 +1006,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
use_triton=use_triton
)
# Any post-initialization that require device information, for example buffers initialization on device.
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
model.eval()
# == step6: (optional) warmup triton == #
if use_triton and warmup_triton:
@ -900,7 +1020,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
# == step7: make model compatible with peft
cls.make_sure_compatible_with_peft(
model, use_triton, quantize_config.desc_act, quantize_config.group_size
model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
)
return cls(
@ -937,10 +1057,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
self.enable_trainable_mode(enabled=False)
@staticmethod
def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int):
def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
GeneralQuantLinear.inject_to_model(
model,
dynamically_import_QuantLinear(use_triton, desc_act, group_size)
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
)
def __getattr__(self, item):
@ -949,5 +1069,4 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
except:
return getattr(self.model, item)
__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]


@ -20,9 +20,14 @@ SUPPORTED_MODELS = [
"RefinedWeb",
"baichuan",
"internlm",
"qwen",
"mpt",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
if compare_transformers_version("v4.33.0", op="ge"):
SUPPORTED_MODELS.append("falcon")
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]


@ -1,5 +1,5 @@
from logging import getLogger
from typing import Union
from typing import Union, Optional
import accelerate
import torch
@ -7,10 +7,9 @@ import torch.nn as nn
from transformers import AutoConfig
import transformers
from ._const import SUPPORTED_MODELS, CPU, CUDA_0
from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from ..utils.import_utils import dynamically_import_QuantLinear
logger = getLogger(__name__)
@ -56,12 +55,15 @@ def make_quant(
bits,
group_size,
name='',
use_triton=False,
use_cuda_fp16=True,
desc_act=False,
trainable=False
use_triton: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
use_qigen: bool = False,
use_cuda_fp16: bool = True,
desc_act: bool = False,
trainable: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen)
if isinstance(module, QuantLinear):
return
@ -80,7 +82,7 @@ def make_quant(
elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
in_features = tmp.weight.shape[0]
out_features = tmp.weight.shape[1]
if (not(desc_act) or group_size == -1) and not use_triton:
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
new_layer = QuantLinear(
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
)
@ -98,9 +100,82 @@ def make_quant(
use_triton=use_triton,
use_cuda_fp16=use_cuda_fp16,
desc_act=desc_act,
trainable=trainable
trainable=trainable,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen
)
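For reference, a sketch of how the extended selector is called elsewhere in this diff; arguments mirror the new signature, values are illustrative, and the final print assumes the exllama extension is actually built:
```python
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

# pick the exllama kernel for a 4-bit, group_size=128, non-act-order model
QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    desc_act=False,
    group_size=128,
    bits=4,
    disable_exllama=False,
    disable_exllamav2=True,
)
print(QuantLinear.QUANT_TYPE)  # "exllama" when the kernel extension is available
```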
def preprocess_checkpoint_qigen(
module,
names,
bits,
group_size,
checkpoint,
name='',
):
try:
import cQIGen as qinfer
except ImportError:
logger.error('cQIGen not installed.')
raise
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
if isinstance(module, QuantLinear):
in_features = module.infeatures
out_features = module.outfeatures
zeros = checkpoint[name + '.qzeros']
scales = checkpoint[name + '.scales'].float()
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != out_features:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != out_features:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
del checkpoint[name + '.qzeros']
del checkpoint[name + '.g_idx']
if name + '.bias' in checkpoint:
checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
else:
checkpoint[name + '.bias'] = torch.zeros(out_features)
checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
if bits == 4:
qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
elif bits == 3:
qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
elif bits == 2:
qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
checkpoint[name + '.qweight'] = qweight
return
for name1, child in module.named_children():
preprocess_checkpoint_qigen(
child,
names,
bits,
group_size,
checkpoint,
name + '.' + name1 if name != '' else name1,
)
def pack_model(
model,
@ -113,7 +188,7 @@ def pack_model(
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=False, disable_exllamav2=True)
if force_layer_back_to_cpu:
model.to(CPU)
@ -121,7 +196,7 @@ def pack_model(
logger.info('Packing model...')
layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act)
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act, disable_exllama=False, disable_exllamav2=True)
qlayers = find_layers(model, [QuantLinear])
for name in qlayers:
logger.info(name)
@ -185,8 +260,118 @@ def simple_dispatch_model(model, device_map):
return model
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size)
def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
"""
The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state.
"""
device_to_buffers_size = {}
model_uses_exllama = False
for name, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
model_uses_exllama = True
device = submodule.qweight.device
if device not in device_to_buffers_size:
device_to_buffers_size[device] = {
"max_dq_buffer_size": 1,
"max_inner_outer_dim": 1
}
if not use_act_order:
submodule._use_act_order = False
else:
submodule._use_act_order = True
# Disable this heuristic for detecting act_order, but it could be used instead of the config.
"""
if submodule.g_idx is None:
submodule.act_order = False
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
submodule.g_idx = None
submodule.act_order = False
else:
submodule.act_order = True
"""
device_to_buffers_size[device]["max_dq_buffer_size"] = max(device_to_buffers_size[device]["max_dq_buffer_size"], submodule.qweight.numel() * 8)
if use_act_order:
device_to_buffers_size[device]["max_inner_outer_dim"] = max(device_to_buffers_size[device]["max_inner_outer_dim"], submodule.infeatures, submodule.outfeatures)
if model_uses_exllama:
# To be honest this is quite ugly, not proud of this.
from exllama_kernels import prepare_buffers, set_tuning_params
device_to_buffers = {}
if use_act_order:
if max_input_length is None:
max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
else:
max_input_len = max_input_length
else:
if max_input_length is not None:
logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")
max_input_len = 1
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_len, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
# Buffers need to be persistent to avoid any bug.
model.device_to_buffers = device_to_buffers
for device, buffers in model.device_to_buffers.items():
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
# The buffers need to have been initialized first before calling make_q4.
for name, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
submodule.post_init()
## exllamav2
fixed_bytes = {}
model_uses_exllamav2 = False
for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device,0))
if model_uses_exllamav2:
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
device_tensors = {}
for device, scratch_bytes in fixed_bytes.items():
device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
# have persistent buffers, otherwise we will get OOM
model.device_tensors = device_tensors
for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
device = submodule.qweight.device
submodule.post_init(temp_dq = model.device_tensors[device])
torch.cuda.empty_cache()
return model
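A minimal usage sketch of the new post-init hook, mirroring how `_base.py` calls it in this diff; the import path assumes this file is `auto_gptq/modeling/_utils.py`, and `model` stands for a quantized model whose QuantLinear layers already hold their weights on their target devices:
```python
from auto_gptq.modeling._utils import autogptq_post_init  # module path assumed

# allocate the exllama/exllamav2 device buffers; max_input_length only matters
# for the exllama act-order path (the default buffer covers 2048 tokens)
model = autogptq_post_init(model, use_act_order=True, max_input_length=4096)
```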
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size, bits: int):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
for n, m in model.named_modules():
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
m.register_buffer('bias', torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
@ -199,7 +384,9 @@ __all__ = [
"get_module_by_name_prefix",
"get_module_by_name_suffix",
"make_quant",
"preprocess_checkpoint_qigen",
"pack_model",
"autogptq_post_init",
"check_and_get_model_type",
"simple_dispatch_model",
"make_sure_no_tensor_in_meta_device"


@ -15,6 +15,7 @@ from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
@ -30,8 +31,10 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"codegen": CodeGenGPTQForCausalLM,
"RefinedWebModel": RWGPTQForCausalLM,
"RefinedWeb": RWGPTQForCausalLM,
"falcon": RWGPTQForCausalLM,
"baichuan": BaiChuanGPTQForCausalLM,
"internlm": InternLMGPTQForCausalLM,
"qwen": QwenGPTQForCausalLM,
"mpt": MPTGPTQForCausalLM,
}
@ -82,6 +85,8 @@ class AutoGPTQForCausalLM:
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
@ -121,6 +126,8 @@ class AutoGPTQForCausalLM:
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
trainable=trainable,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
**keywords
)
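A hedged end-to-end sketch of the loader flags exposed here; the model id is hypothetical, and the flag values match the new defaults above (v1 exllama kernel off, exllamav2 kernel on):
```python
from auto_gptq import AutoGPTQForCausalLM

# load a 4-bit GPTQ checkpoint and route matmuls through the exllamav2 kernel
model = AutoGPTQForCausalLM.from_quantized(
    "some-org/some-model-4bit-gptq",   # hypothetical model id
    device="cuda:0",
    use_triton=False,
    disable_exllama=True,      # keep the v1 exllama kernel off
    disable_exllamav2=False,   # use the new exllamav2 kernel
)
```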


@ -0,0 +1,16 @@
from ._base import *
class QwenGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "QWenBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f", "transformer.visual"]
inside_layer_modules = [
["attn.c_attn"],
["attn.c_proj"],
["mlp.w1", "mlp.w2"],
["mlp.c_proj"]
]
__all__ = ["QwenGPTQForCausalLM"]


@ -234,10 +234,13 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs
):
config = model.config
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
for name, m in model.named_modules():
if not isinstance(m, GPTJAttention):
@ -252,7 +255,16 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# See fused_llama_attn.py comment
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (


@ -79,7 +79,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
past_key_value = (key_states, value_states) if use_cache else None
if compare_pytorch_version("v2.0.0", op="eq"):
if compare_pytorch_version("v2.0.0", op="ge"):
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
@ -134,12 +134,15 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs
):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
for name, m in model.named_modules():
if not isinstance(m, LlamaAttention):
@ -152,7 +155,18 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# TODO: support it. The issue lies maybe in the line:
# int groups = qzeros.size(0);
# in exllama_ext.cpp
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (


@ -8,7 +8,6 @@ class GeneralQuantLinear(nn.Linear):
out_features=quant_linear_module.outfeatures,
bias=True
)
self.infeatures = quant_linear_module.infeatures
self.outfeatures = quant_linear_module.outfeatures
self.bits = quant_linear_module.bits
@ -18,15 +17,15 @@ class GeneralQuantLinear(nn.Linear):
self.weight.requires_grad = False
self.weight.data = quant_linear_module.qweight
self.qweight = self.weight
self.register_buffer('qweight', quant_linear_module.qweight)
self.bias.data = quant_linear_module.bias
self.qweight.requires_grad = False
self.bias.requires_grad = False
self.qzeros = quant_linear_module.qzeros
self.scales = quant_linear_module.scales
self.g_idx = quant_linear_module.g_idx
self.register_buffer('qzeros', quant_linear_module.qzeros)
self.register_buffer('scales', quant_linear_module.scales)
self.register_buffer('g_idx', quant_linear_module.g_idx)
if hasattr(quant_linear_module, "wf"):
self.wf = quant_linear_module.wf


@ -20,6 +20,8 @@ except ImportError:
class QuantLinear(nn.Module):
QUANT_TYPE = "cuda"
def __init__(
self,
bits,
@ -88,6 +90,9 @@ class QuantLinear(nn.Module):
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):


@ -19,6 +19,8 @@ except ImportError:
class QuantLinear(nn.Module):
QUANT_TYPE = "cuda-old"
def __init__(
self,
bits,
@ -90,6 +92,9 @@ class QuantLinear(nn.Module):
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):


@ -0,0 +1,171 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllama
from logging import getLogger
import torch
import torch.nn as nn
import math
import numpy as np
import transformers
logger = getLogger(__name__)
try:
from exllama_kernels import make_q4, q4_matmul
except ImportError:
logger.error('exllama_kernels not installed.')
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllama"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllama kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
if trainable:
raise NotImplementedError("Exllama kernel does not support training.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.trainable = trainable
self.maxq = 2 ** self.bits - 1
assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.width = self.qweight.shape[1]
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx.to("cpu") if self._use_act_order else None,
self.qweight.device.index
)
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out = ext_q4_matmul(x.half(), self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out
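An illustrative construction of the new exllama `QuantLinear`; this is a sketch only, it requires a CUDA device with `exllama_kernels` built, the `auto_gptq.modeling._utils` path is assumed, and in real use the zero-initialized buffers would be filled from a checkpoint before calling the post-init hook:
```python
import torch
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from auto_gptq.modeling._utils import autogptq_post_init  # module path assumed

# 4-bit layer, group_size=128, 4096 -> 4096, no bias
layer = QuantLinear(bits=4, group_size=128, infeatures=4096, outfeatures=4096, bias=False).to("cuda:0")

# the post-init hook prepares the exllama buffers and builds the Q4Matrix handle
layer = autogptq_post_init(layer, use_act_order=False)

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda:0")
print(layer(x).shape)  # torch.Size([1, 4096])
```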


@ -0,0 +1,188 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from logging import getLogger
import torch
import torch.nn as nn
import math
logger = getLogger(__name__)
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
logger.error('exllamav2_kernels not installed.')
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def _torch_device(idx):
if idx == -1: return "cpu"
return f"cuda:{idx}"
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
"""
Create Q matrix
"""
# EXL2
# won't work as the moment because the tensors are not the same.
if "q_weight" in w:
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
return make_q_matrix(w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
temp_dq)
# GPTQ
elif "qweight" in w:
if w["scales"].dtype == torch.float:
w["scales"] = w["scales"].half()
# GPTQ with g_idx (act_order)
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
temp_dq)
# GPTQ without g_idx
else:
return make_q_matrix(w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,
temp_dq)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
if trainable:
raise NotImplementedError("Exllamav2 kernel does not support training.")
self.q_handle = None
self.q_tensors = None
self.padding = - outfeatures % 32
self.infeatures = infeatures
self.outfeatures = outfeatures + self.padding
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.trainable = trainable
self.maxq = 2 ** self.bits - 1
assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
"qweight":self.qweight,
"qzeros":self.qzeros,
"scales":self.scales,
"g_idx":self.g_idx
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq
)
def forward(self, x, force_cuda = False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
scratch_bytes: int
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device_idx, scratch_bytes):
self.device_idx = device_idx
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
def get_scratch_slice(self, size_bytes):
if self.scratch is None: self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
scratch_slice = self.scratch.narrow(0, 0, size_half)
return scratch_slice
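To get a feel for the scratch sizing above, a quick back-of-the-envelope computation for a hypothetical 4096x4096 layer with the default max_input_len=2048 and max_batch_size=8:
```python
infeatures = outfeatures = 4096                # hypothetical layer shape
temp_dq = infeatures * outfeatures * 2 + 128   # 33_554_560 bytes, about 32 MiB
temp_fwd = outfeatures * 2048 * 8 * 4 + 128    # 268_435_584 bytes, about 256 MiB
scratch_space_fixed = temp_dq + temp_fwd       # about 288 MiB
# autogptq_post_init keeps only the per-device maximum of these values, and
# ExLlamaV2DeviceTensors.prepare() lazily allocates scratch_bytes // 2 fp16 elements.
print(scratch_space_fixed)
```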


@ -0,0 +1,262 @@
from copy import deepcopy
import torch
from torch import nn
from tqdm import tqdm
import gc
import math
import numpy as np
from gekko import GEKKO
from logging import getLogger
logger = getLogger(__name__)
try:
import cQIGen as qinfer
except ImportError:
logger.error('cQIGen not installed.')
raise
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
m = GEKKO() # create GEKKO model
#cinfergen if bits==3:
# tu = tu*3
B = m.Const(value=bits)
TP = m.Const(value=T//p)
k = m.Var(1,integer=True,lb=1)
z = m.Var(1,integer=True,lb=1)
w = m.Var(1,integer=True,lb=1)
y = m.Var(1,integer=True,lb=1)
v = m.Var(1,integer=True,lb=1)
mb = m.Var(mu,integer=True,lb=1)
if gs != -1:
gg = m.Var(1,integer=True,lb=1)
tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
L = m.Var(integer=True,lb=0,ub=l1)
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
m.Equation(mb * k == M)
if gs != -1:
m.Equation(gs * gg == mb)
# m.Equation(tb * z == T)
m.Equation(tb * z == TP)
m.Equation(mu * w == mb)
m.Equation(tu * y == tb)
# m.Equation(tb * v == tt)
m.Maximize(L)
m.options.SOLVER = 1
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 2', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
try:
m.solve(disp=False)
except:
try:
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 1', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
m.solve(disp=False)
except:
# mytb = T//p
mytb = tu
if gs != -1:
mymb = gs
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
mymb += gs
while M % mymb != 0:
mymb -= gs
return (int(mymb), int(mytb))
else:
mymb = mu
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
mymb += mu
while M % mymb != 0:
mymb -= mu
return (int(mymb), int(mytb))
return (int(mb.value[0]), int(tb.value[0]))
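For intuition, a tiny helper (an illustrative assumption, not part of this file) that evaluates the working-set expression the GEKKO model maximizes under the cache constraint, 32*mb*N + bits*mb*tb + 32*tb*N <= l1:
def l1_working_set(N, mb, tb, bits):
    # Activations for mb inputs and tb outputs at 32 bits each, plus the packed weight tile.
    return 32 * mb * N + bits * mb * tb + 32 * tb * N
# For example, N=1, mb=16, tb=32, bits=4 gives 3584, comfortably below the default l1 = 2**18.
print(l1_working_set(1, 16, 32, 4))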
params = {}
def compute_reductions(x, gs=-1, cpp=True):
if cpp:
if len(x.shape) != 1:
rows, cols = x.shape
else:
rows = 1
cols = x.shape[0]
if gs == -1:
out = torch.zeros(rows).float().contiguous()
mygs = cols
else:
out = torch.zeros(rows, cols // gs).float().contiguous()
mygs = gs
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
return out
if gs == -1:
if len(x.shape) != 1:
return torch.sum(x,1)
else:
return torch.sum(x)
else:
if len(x.shape) != 1:
rows, cols = x.shape
out = torch.zeros(rows, cols // gs).float().contiguous()
for i in range(cols // gs):
out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
return out
else:
cols = x.shape[0]
out = torch.zeros(cols // gs).float().contiguous()
for i in range(cols // gs):
out[i] = torch.sum(x[i*gs:(i+1)*gs])
return out
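A short usage sketch of the reduction above; the shapes follow directly from the code, and the random input is only for illustration.
x = torch.randn(4, 256)
full = compute_reductions(x, gs=-1, cpp=False)     # shape (4,): one sum per row
grouped = compute_reductions(x, gs=64, cpp=False)  # shape (4, 4): one sum per 64-column group of each row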
def process_zeros_scales(zeros, scales, bits, M):
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != M:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != M:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
return new_zeros, new_scales
class QuantLinear(nn.Module):
QUANT_TYPE = "qigen"
def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
super().__init__()
if bits not in [2, 4]:
raise NotImplementedError("Only 2,4 bits are supported.")
if trainable:
raise NotImplementedError("Qigen kernel does not support training.")
self.bits = bits
pack = 32 // bits
self.infeatures = infeatures
self.outfeatures = outfeatures
n = hint
m = self.infeatures
t = self.outfeatures
# register tile sizes are fixed for now
if bits == 3:
packed = 32
unroll = 3
nu = 1 #args.n
mu = 32
tu = 32
else:
packed = 32 // bits
unroll = 2
nu = 1 #args.n
mu = 16
tu = 32
nb = n # it's always small for transformers
global params
if (m,t) in params:
mb = params[(m,t)][0]
tb = params[(m,t)][1]
else:
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
params[(m,t)] = (mb,tb)
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert(np.sum(split) == t)
split = split.astype(int)
self.tt = int(split[0])
if split[0] == split[-1]:
self.cutoff = int(p+1)
else:
self.cutoff = int(idx + 1)
self.mb = mb #// packed
self.tb = tb
self.group_size = group_size
self.register_buffer('bias', torch.zeros(self.outfeatures))
self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
if bits == 4:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
elif bits == 3:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
elif bits == 2:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
B = x.shape[0]
new_x = x.T.contiguous()
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
if self.group_size == -1:
if self.bits == 4:
qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 2:
qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 3:
qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
else:
if self.bits == 4:
qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 2:
qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 3:
qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
return out.reshape(out_shape)

View file

@ -21,6 +21,8 @@ except ImportError:
class QuantLinear(nn.Module, TritonModuleMixin):
QUANT_TYPE = "triton"
def __init__(
self,
bits,
@ -64,6 +66,9 @@ class QuantLinear(nn.Module, TritonModuleMixin):
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):

View file

@ -60,7 +60,7 @@ class GPTQ:
self.H += inp.matmul(inp.t())
def fasterquant(
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
@ -80,10 +80,26 @@ class GPTQ:
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i:(i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
@ -96,11 +112,6 @@ class GPTQ:
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
@ -116,6 +127,7 @@ class GPTQ:
d = Hinv1[i, i]
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
@ -123,6 +135,11 @@ class GPTQ:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
@ -148,10 +165,12 @@ class GPTQ:
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]

View file

@ -0,0 +1,48 @@
import gc
import torch
def exllama_set_max_input_length(model, max_input_length: int):
"""
This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the
default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
"""
# The import is done here to avoid a global import. Arguably this is quite ugly; lazy loading would be better.
from exllama_kernels import prepare_buffers, cleanup_buffers_cuda
if not model.quantize_config.desc_act:
raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")
device_to_buffers_size = {}
for device, buffers in model.device_to_buffers.items():
device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}
# For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
for key in list(model.device_to_buffers.keys()):
del model.device_to_buffers[key]
model.device_to_buffers = None
del model.device_to_buffers
gc.collect()
torch.cuda.empty_cache()
cleanup_buffers_cuda()
device_to_buffers = {}
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])
# Keep a persistent reference to the buffers on the model so they are not freed while the kernels still use them.
model.device_to_buffers = device_to_buffers
return model
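A usage sketch mirroring the example embedded in the TORCH_CHECK message of the exllama extension further below; it assumes model was loaded with the exllama backend and desc_act=True.
from auto_gptq import exllama_set_max_input_length
# Re-allocate the act-order buffers for prompts of up to 4096 tokens.
model = exllama_set_max_input_length(model, 4096)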

View file

@ -1,4 +1,6 @@
from packaging.version import parse as parse_version
from logging import getLogger
import torch
try:
import triton
@ -8,18 +10,53 @@ except ImportError:
TRITON_AVAILABLE = False
try:
import autogptq_cuda
import autogptq_cuda_256
import autogptq_cuda_64
AUTOGPTQ_CUDA_AVAILABLE = True
except:
AUTOGPTQ_CUDA_AVAILABLE = False
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int):
try:
import exllama_kernels
EXLLAMA_KERNELS_AVAILABLE = True
except:
EXLLAMA_KERNELS_AVAILABLE = False
try:
import exllamav2_kernels
EXLLAMAV2_KERNELS_AVAILABLE = True
except:
EXLLAMAV2_KERNELS_AVAILABLE = False
try:
import cQIGen as qinfer
QIGEN_AVAILABLE = True
except:
QIGEN_AVAILABLE = False
logger = getLogger(__name__)
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2: bool = False, use_qigen: bool = False):
if use_qigen:
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
else:
if use_triton:
if torch.version.hip:
logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
else:
if not desc_act or group_size == -1:
if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
elif not desc_act or group_size == -1:
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
else:
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
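A minimal resolution sketch; the argument values are illustrative, and the resulting class depends on which kernel packages are importable on the machine.
QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    desc_act=False,
    group_size=128,
    bits=4,
    disable_exllama=True,
    disable_exllamav2=False,
)
# With 4-bit weights and exllamav2_kernels installed this resolves to the exllamav2
# QuantLinear; otherwise it falls back to the old CUDA implementation.
print(QuantLinear.QUANT_TYPE)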

View file

@ -402,7 +402,7 @@ def get_gptq_peft_model(
with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config)
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
except:

View file

@ -30,8 +30,9 @@
// }
// #endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *address_as_ui;
@ -1160,7 +1161,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
res2 = {};
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
@ -1172,7 +1173,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
i += width;
k += 8;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@ -1294,7 +1295,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
}
res2 = {};
std::memset(&res2, 0, sizeof(half2));
tmp1 = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
@ -1324,7 +1325,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
i += width;
k += 5;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@ -1409,10 +1410,15 @@ __global__ void VecQuant4MatMulKernelFaster_old(
while (k < blockwidth2) {
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
res2 = {};
//std::memset(&res2, 0, sizeof(half2));
//res2 = __float2half2_rn((float)0.);
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
@ -1420,7 +1426,9 @@ __global__ void VecQuant4MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
i += width;
k += 4;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);

View file

@ -30,7 +30,8 @@
// }
// #endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
@ -1161,7 +1162,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
res2 = {};
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
@ -1173,7 +1174,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
i += width;
k += 8;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@ -1295,7 +1296,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
}
res2 = {};
std::memset(&res2, 0, sizeof(half2));
tmp1 = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
@ -1325,7 +1326,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
i += width;
k += 5;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
@ -1413,7 +1414,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
res2 = {};
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
@ -1421,7 +1422,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
i += width;
k += 4;
res += __half2float(res2.x) + __half2float(res2.y);
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);

View file

@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

View file

@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}

View file

@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif

View file

@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "../util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
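The comment above states the invariant the remap preserves. A tiny NumPy check of the same idea (purely illustrative, independent of the kernel), assuming x_map[c] is the source column for destination column c and that the quantized weight rows are reordered with the same map, as make_sequential does in q4_matrix.cu further below:
import numpy as np
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
w = rng.standard_normal((8, 3))
x_map = rng.permutation(8)   # destination column c reads source column x_map[c]
seq_x = x[:, x_map]          # what column_remap_cuda produces
seq_w = w[x_map, :]          # weight rows reordered with the same map
assert np.allclose(seq_x @ seq_w, x @ w)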

View file

@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif

View file

@ -0,0 +1,260 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cu_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}

View file

@ -0,0 +1,43 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "../tuning.h"
// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif

View file

@ -0,0 +1,225 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
height / 8,
1
);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
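The host-side part above is a counting sort on g_idx. A brief Python sketch of the same construction (illustrative only, with a made-up g_idx), producing an x_map where sequential position t reads original row x_map[t]:
def make_sequential_map(g_idx, groups):
    # Histogram of rows per group, then exclusive prefix sum for the group start offsets.
    counts = [0] * groups
    for g in g_idx:
        counts[g] += 1
    offsets, acc = [], 0
    for c in counts:
        offsets.append(acc)
        acc += c
    x_map_inv = [0] * len(g_idx)      # original row -> sequential position
    for row, g in enumerate(g_idx):
        x_map_inv[row] = offsets[g]
        offsets[g] += 1
    x_map = [0] * len(g_idx)          # sequential position -> original row
    for row, pos in enumerate(x_map_inv):
        x_map[pos] = row
    return x_map
print(make_sequential_map([1, 0, 1, 0], groups=2))  # [1, 3, 0, 2]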
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}

View file

@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif

View file

@ -0,0 +1,255 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
}

View file

@ -0,0 +1,49 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif

View file

@ -0,0 +1,294 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
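Both 4-bit views unpack eight weights per uint32, least-significant nibble first. A quick Python illustration of the row-view indexing above; the packed constant is arbitrary.
def q4_row_item(data, row, width, column):
    shift = (column & 0x07) * 4
    return (data[row * width // 8 + column // 8] >> shift) & 0x0F
row = [0x87654321]  # one row of width 8
print([q4_row_item(row, 0, 8, c) for c in range(8)])  # [1, 2, 3, 4, 5, 6, 7, 8]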
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif

View file

@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning
{
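// Field meanings (assumed from exllama's usage of this struct):
// matmul_recons_thd: batch-size threshold above which the weight is reconstructed to fp16 and multiplied with cuBLAS
// matmul_fused_remap: fuse the act-order remap (x_map) into the matmul kernel rather than remapping separately
// matmul_no_half2: fall back to scalar half math instead of half2 intrinsics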
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif

View file

@ -0,0 +1,33 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif

View file

@ -0,0 +1,13 @@
#ifndef _config_h
#define _config_h
#define MAX_Q_GEMM_ROWS 50
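// Above this many input rows, the matmul reconstructs the fp16 weight matrix and calls cuBLAS instead of the quantized kernels (see gemm_half_q_half_cuda).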
#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0
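// QMODE_xBIT == 1 selects the shuffled/packed layout for that bit width; 0 selects the simple unshuffled fallback in the corresponding qdq_x.cuh.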
#endif

View file

@ -0,0 +1,12 @@
#ifndef _util_h
#define _util_h
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#endif

View file

@ -0,0 +1,56 @@
#ifndef _compat_cuh
#define _compat_cuh
// atomicAdd for half types, to support CC < 7.x
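// Emulated with a 32-bit atomicCAS on the aligned word containing the target half.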
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

View file

@ -0,0 +1,121 @@
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "quant/qdq_util.cuh"
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
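// Eight 4-bit values are packed per uint32_t; (column & 7) selects the nibble within the word.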
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
#endif

View file

@ -0,0 +1,238 @@
#include "q_gemm.cuh"
#include "util.cuh"
#include "matrix_view.cuh"
#include "../config.h"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
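// Each block covers BLOCK_KN_SIZE rows of k and BLOCK_KN_SIZE * 4 columns of n; m is tiled in chunks of up to BLOCK_M_SIZE_MAX rows.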
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"
#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
void gemm_half_q_half_cuda_part
(
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
bool clear
)
{
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
kernel<<<gridDim, blockDim>>>
(
a,
b->cuda_q_weight,
b->cuda_q_scale,
b->cuda_q_scale_max,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_8,
b->rows_6,
b->rows_5,
b->rows_4,
b->rows_3,
b->rows_2,
clear
);
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
// DBGX((uint64_t) b->cuda_q_perm);
// DBGI(b->rows_4);
// DBGI(b->height);
kernel<<<gridDim, blockDim>>>
(
a,
b->cuda_q_weight,
b->cuda_gptq_qzeros,
b->cuda_gptq_scales,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_4,
clear
);
}
}
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear,
half* temp_dq,
bool force_cuda
)
{
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq;
b->reconstruct(temp_dq);
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
const half alpha = __float2half(1.0f);
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
cublasHgemm(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
size_n, size_m, size_k,
&alpha, temp_dq, size_n,
a, size_k,
&beta, c, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasGemmEx(cublas_handle,
// CUBLAS_OP_N, CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n,
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
}
else
{
//printf("cuda\n");
// Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n);
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk;
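// Run all full BLOCK_M_SIZE_MAX-row chunks in one launch, then the remaining rows in a second launch.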
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
}
}
}
__global__ void clear_kernel
(
half* __restrict__ c,
const int size_m,
const int size_n
)
{
int m = blockIdx.y;
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
if (n >= size_n) return;
int4* c_ptr = (int4*)(c + m * size_n + n);
*c_ptr = {};
}
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
)
{
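// NOTE: the early return below leaves this helper disabled; clearing appears to be handled inside the GEMM kernels via their clear flag instead.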
return;
dim3 blockDim, gridDim;
blockDim.x = CLEAR_N_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.y = size_m;
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
}

View file

@ -0,0 +1,33 @@
#ifndef _q_gemm_cuh
#define _q_gemm_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q_matrix.cuh"
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
);
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
);
#endif

View file

@ -0,0 +1,484 @@
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const int,
const int,
const int,
const int,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
block_a_ptr[t] = a0;
}
}
// Clear
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
// Preload scales
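// Scales are stored as 4-bit indices; the effective group scale is (index + 1)^2 * b_q_scale_max[group].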
float scales[MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
qscales[0]++;
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
}
// a, b offset
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
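// Rows are packed 32 at a time and a b-bit group occupies b uint32 words per 32 rows, so convert the row offset into a packed-word offset: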
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
// Column result
float block_c[m_count][4] = {};
// Dequantize groups
int k = offset_k;
while (k < rows_8 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[5];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
// Accumulate column sums in c
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL;
}

View file

@ -0,0 +1,219 @@
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
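// Eight 4-bit values are packed per uint32_t, so a row offset of offset_k corresponds to offset_k / 8 packed rows.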
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
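// GPTQ stores zero points offset by -1, hence the + 1 below.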
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
// __syncthreads();
// Column result
float block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}

View file

@ -0,0 +1,603 @@
#include "q_matrix.cuh"
#include "matrix_view.cuh"
#include "util.cuh"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
// QMatrix constructor
QMatrix::QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
) :
device(_device),
height(_height),
width(_width),
groups(_groups),
temp_dq(_temp_dq)
{
cudaSetDevice(device);
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
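// Infer the (power-of-two) group size from the number of groups and the total row count.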
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
// Create group map
rows_8 = 0;
rows_6 = 0;
rows_5 = 0;
rows_4 = 0;
rows_3 = 0;
rows_2 = 0;
if (!is_gptq)
{
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
}
free(cpu_q_groups);
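// Convert per-width row counts into cumulative k boundaries: rows_N marks the end of the region quantized at N or more bits.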
rows_6 += rows_8;
rows_5 += rows_6;
rows_4 += rows_5;
rows_3 += rows_4;
rows_2 += rows_3;
}
else
{
rows_4 = height;
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
}
// Shuffle quantized data
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
// Reconstruct b[k,n] (GPTQ)
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_4
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
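// perm maps packed (sequential) row indices back to the original row order, so the reconstructed fp16 matrix matches the caller's layout.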
__shared__ uint16_t perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
// Reconstruct b[k,n]
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
// Preload remapping table
int t = threadIdx.x;
__shared__ uint16_t perm[BLOCK_KN_SIZE];
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
// Column
int n = offset_n + t;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
int lk = 0;
__syncthreads();
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
dequant_8bit_8(q_0, q_1, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
uint32_t q_3 = *b_ptr; b_ptr += size_n;
uint32_t q_4 = *b_ptr; b_ptr += size_n;
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_4bit_8(q_0, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_2bit_16(q_0, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
}
void QMatrix::reconstruct(half* out)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
reconstruct_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
height,
width,
groupsize,
groups,
out,
rows_8,
rows_6,
rows_5,
rows_4,
rows_3,
rows_2
);
}
else
{
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_gptq_qzeros,
cuda_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
groups,
out,
rows_4
);
}
}
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
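// Operate on 64-bit words (two adjacent packed columns); each output word gathers one 4-bit value per column from 8 permuted source rows.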
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
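// Reorders the quantized rows into sequential group order (act-order), filling cuda_q_perm / cuda_q_invperm and re-packing qweight accordingly.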
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Reduce to uint16_t
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
// Move to CUDA
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_new_qweight,
cuda_q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}

View file

@ -0,0 +1,71 @@
#ifndef _q_matrix_cuh
#define _q_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define MAX_SUPERGROUPS 16
class QMatrix
{
public:
int device;
bool is_gptq;
int height;
int width;
int groups;
int groupsize;
int rows_8;
int rows_6;
int rows_5;
int rows_4;
int rows_3;
int rows_2;
uint32_t* cuda_q_weight = NULL;
uint16_t* cuda_q_perm = NULL;
uint16_t* cuda_q_invperm = NULL;
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
half* temp_dq;
QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
);
~QMatrix();
void reconstruct(half* out);
void make_sequential(const uint32_t* cpu_g_idx);
private:
};
#endif

View file

@ -0,0 +1,103 @@
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_2BIT == 1
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
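// fp16 bias trick: 0x6400 is 1024.0 in half; OR-ing the quantized bits into the mantissa yields (q << shift) + 1024 without any int-to-float conversion.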
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z4 = __halves2half2(z4_, z4_);
const half2 z16 = __halves2half2(z16_, z16_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
#else
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

View file

@ -0,0 +1,169 @@
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_3BIT == 1
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

View file

@ -0,0 +1,227 @@
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
#else
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
#endif
#endif
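A minimal, dependency-free Python sketch of the fp16 bias trick the 4-bit path above relies on (per its inline comments): OR-ing a 4-bit value q into the mantissa of the half pattern 0x6400 (= 1024.0) yields 1024 + q, and OR-ing the zero point into 0xE400 (= -1024.0) yields -(1024 + zero), so a single half2 add recovers q - zero without any int-to-float conversion. Names below are illustrative only.
import struct

def half_from_bits(bits: int) -> float:
    # reinterpret a 16-bit pattern as an IEEE-754 half
    return struct.unpack('<e', struct.pack('<H', bits))[0]

q, zero = 11, 8                             # any 4-bit weight value and zero point
plus_bias = half_from_bits(0x6400 | q)      # 1024.0 + q
minus_bias = half_from_bits(0xE400 | zero)  # -(1024.0 + zero)
assert plus_bias + minus_bias == q - zero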

View file

@ -0,0 +1,207 @@
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
uint32_t qd = q[3 * stride];
uint32_t qe = q[4 * stride];
// qa: 66555554 44443333 32222211 11100000
// qb: ccccbbbb baaaaa99 99988888 77777666
// qc: jiiiiihh hhhggggg fffffeee eedddddc
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
uint32_t qf = qe >> 22;
qe <<= 8;
qe |= qd >> 24;
qd <<= 6;
qd |= qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: 555554 44443333 32222211 11100000
// qb: bbbbba aaaa9999 98888877 77766666
// qc: hhhhhg ggggffff feeeeedd dddccccc
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
// qf: vv vvvuuuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
uint32_t zd = 0;
uint32_t ze = 0;
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
// za: 5555533 33311111 4444422 22200000
// zb: bbbbb99 99977777 aaaaa88 88866666
// zc: hhhhhff fffddddd gggggee eeeccccc
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
// ze: tttttrr rrrppppp sssssqq qqqooooo
// qf: vv vvvuuuuu
za |= ((qf & 0x001) >> 0) << 15;
zb |= ((qf & 0x002) >> 1) << 15;
zc |= ((qf & 0x004) >> 2) << 15;
zd |= ((qf & 0x008) >> 3) << 15;
ze |= ((qf & 0x010) >> 4) << 15;
za |= ((qf & 0x020) >> 5) << 31;
zb |= ((qf & 0x040) >> 6) << 31;
zc |= ((qf & 0x080) >> 7) << 31;
zd |= ((qf & 0x100) >> 8) << 31;
ze |= ((qf & 0x200) >> 9) << 31;
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
// zb: vbbbbb99 99977777 uaaaaa88 88866666
// zc: vhhhhhff fffddddd ugggggee eeeccccc
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
// ze: vtttttrr rrrppppp usssssqq qqqooooo
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
q[3 * stride] = zd;
q[4 * stride] = ze;
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y32_ = __float2half_rn(1.0f / 32.0f);
const half2 y32 = __halves2half2(y32_, y32_);
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z32 = __halves2half2(z32_, z32_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
uint32_t qd = q_3;
uint32_t qe = q_4;
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
qa >>= 10;
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
qa >>= 5;
qa &= 0x00010001;
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
qb >>= 10;
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
qb >>= 4;
qb &= 0x00020002;
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
qc >>= 10;
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
qc >>= 3;
qc &= 0x00040004;
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
qd >>= 10;
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
qd >>= 2;
qd &= 0x00080008;
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
qe >>= 10;
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
qe >>= 1;
qe &= 0x00100010;
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hadd2( q3.as_half2, z1);
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hadd2( q6.as_half2, z1);
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
dq[ 8] = __hadd2( q8.as_half2, z1);
dq[ 9] = __hadd2( q9.as_half2, z1);
dq[10] = __hfma2(q10.as_half2, y32, z32);
dq[11] = __hadd2(q11.as_half2, z1);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y32, z32);
dq[14] = __hadd2(q14.as_half2, z1);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

View file

@ -0,0 +1,44 @@
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

View file

@ -0,0 +1,38 @@
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

View file

@ -0,0 +1,51 @@
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif
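A pure-Python sketch (illustrative names, shift < 32 assumed) of what the two-argument exb overload above computes via __funnelshift_rc: shift the 64-bit concatenation q1:q0 right and mask out a field that straddles the word boundary; this is how the 3-, 5- and 6-bit fallback paths read values split across packed words.
def exb2(q1: int, q0: int, shift: int, mask: int) -> int:
    # mirrors exb(q1, q0, shift, mask): extract a bit field from the 64-bit value q1:q0
    return (((q1 << 32) | q0) >> shift) & mask

# a 3-bit value whose low two bits sit in bits 30..31 of q0 and whose top bit is bit 0 of q1
q0 = 0b11 << 30
q1 = 0b1
assert exb2(q1, q0, 30, 0x07) == 0b111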

View file

@ -0,0 +1,32 @@
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
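A quick double-precision check (a sketch that ignores fp16 rounding) that dq_scale in qdq_util.cuh and dq_scale_ above agree once max_scale is premultiplied by 1/256, as the comment in qdq_util.cuh states: (qs + 1)^2 * (s / 256) equals ((qs + 1) / 16)^2 * s.
def dq_scale(qs: int, max_scale_pre: float) -> float:
    # mirrors dq_scale: (qs + 1)^2 * premultiplied max_scale
    return (qs + 1) ** 2 * max_scale_pre

def dq_scale_(qs: int, max_scale: float) -> float:
    # mirrors dq_scale_: ((qs + 1) / 16)^2 * max_scale
    return ((qs + 1) / 16.0) ** 2 * max_scale

s = 0.37
assert all(abs(dq_scale(qs, s / 256.0) - dq_scale_(qs, s)) < 1e-12 for qs in range(16))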

View file

@ -0,0 +1,134 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"
// Some decluttering macros
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
// Quant matrix
uintptr_t make_q_matrix
(
torch::Tensor q_weight,
torch::Tensor q_perm,
torch::Tensor q_invperm,
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
torch::Tensor temp_dq
)
{
TORCH_CHECK_DTYPE(q_weight, kInt);
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
int device = q_weight.device().index();
int width = q_weight.size(1);
int groups;
int height;
if (!q_scale.device().is_meta())
{
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
groups = q_scale.size(0);
height = q_invperm.size(0);
}
else
{
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
groups = gptq_qzeros.size(0);
height = q_weight.size(0) * 8;
}
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
QMatrix* m = new QMatrix
(
device,
height,
width,
groups,
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
return reinterpret_cast<uintptr_t> (m);
}
void gemm_half_q_half
(
torch::Tensor a,
uintptr_t b,
torch::Tensor c,
bool force_cuda
)
{
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
TORCH_CHECK_DTYPE(a, kHalf);
TORCH_CHECK_DTYPE(c, kHalf);
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
qm,
(half*) c.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
true,
NULL,
force_cuda
);
}
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}

File diff suppressed because it is too large

View file

@ -0,0 +1,149 @@
def load_int(to, address, const=True):
if const:
return f"const __m256i {to} = _mm256_loadu_si256({address});"
else:
return f"__m256i {to} = _mm256_loadu_si256({address});"
def load_fp(to, address, const=True):
if const:
return f"const __m256 {to} = _mm256_loadu_ps({address});"
else:
return f"__m256 {to} = _mm256_loadu_ps({address});"
# to = a * b + c
def vfma(to, a, b, c):
return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"
def vsrli(to, a, b):
return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"
def vand(to, a, b):
return f"const __m256i {to} = _mm256_and_si256({a}, {b});"
def vbroadcast_fp(to, a):
return f"const __m256 {to} = _mm256_set1_ps({a});"
def vbroadcast_int32(to, a):
return f"__m256i {to} = _mm256_set1_epi32({a});"
def vsetzero(to):
return f"__m256 {to} = _mm256_setzero_ps();"
def vcvtepi32_ps(to, a):
return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"
def _256extractf128_ps(to, a, imm):
return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"
def _256castps256_ps128(to, a):
return f"const __m128 {to} = _mm256_castps256_ps128({a});"
def _add_ps(to, a, b):
return f"const __m128 {to} = _mm_add_ps({a}, {b});"
def _movehl_ps(to, a, b):
return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"
def _shuffle_ps(to, a, b, imm):
return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"
def _cvtss_f32(to, a):
return f"const float {to} = _mm_cvtss_f32({a});"
def _reduce8_acc(a, b, c, d, e, f, g, h):
res = ""
res += _256extractf128_ps("hi_quad0", a, 1)
res += _256extractf128_ps("hi_quad1", b, 1)
res += _256extractf128_ps("hi_quad2", c, 1)
res += _256extractf128_ps("hi_quad3", d, 1)
res += _256extractf128_ps("hi_quad4", e, 1)
res += _256extractf128_ps("hi_quad5", f, 1)
res += _256extractf128_ps("hi_quad6", g, 1)
res += _256extractf128_ps("hi_quad7", h, 1)
res += _256castps256_ps128("lo_quad0", a)
res += _256castps256_ps128("lo_quad1", b)
res += _256castps256_ps128("lo_quad2", c)
res += _256castps256_ps128("lo_quad3", d)
res += _256castps256_ps128("lo_quad4", e)
res += _256castps256_ps128("lo_quad5", f)
res += _256castps256_ps128("lo_quad6", g)
res += _256castps256_ps128("lo_quad7", h)
res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")
res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")
res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")
res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)
res += _add_ps("sum0", "sum_dual0", "hi0")
res += _add_ps("sum1", "sum_dual1", "hi1")
res += _add_ps("sum2", "sum_dual2", "hi2")
res += _add_ps("sum3", "sum_dual3", "hi3")
res += _add_ps("sum4", "sum_dual4", "hi4")
res += _add_ps("sum5", "sum_dual5", "hi5")
res += _add_ps("sum6", "sum_dual6", "hi6")
res += _add_ps("sum7", "sum_dual7", "hi7")
res += _cvtss_f32(f"f{a}", "sum0")
res += _cvtss_f32(f"f{b}", "sum1")
res += _cvtss_f32(f"f{c}", "sum2")
res += _cvtss_f32(f"f{d}", "sum3")
res += _cvtss_f32(f"f{e}", "sum4")
res += _cvtss_f32(f"f{f}", "sum5")
res += _cvtss_f32(f"f{g}", "sum6")
res += _cvtss_f32(f"f{h}", "sum7")
return res
acc_idx = 0
def _reduce_add(a):
global acc_idx
res = ""
res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
acc_idx += 1
return res

View file

@ -0,0 +1,302 @@
#include <iostream>
#include "forward.h"
#include <cstring>
#include <algorithm>
#include <vector>
#include <chrono>
#include <fstream>
#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
void print_matrix(std::string name, float* A, int N, int M){
std::cout<<name<<std::endl;
for(int i = 0; i < N; i++){
for(int j = 0; j < M; j++){
std::cout << A[i*M+j] << " ";
}
std::cout << std::endl;
}
std::cout<<std::endl;
}
void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
// triple loop matmul and add bias
for (int i = 0; i < n; i++){
for (int j = 0; j < t; j++){
float sum = 0;
for (int k = 0; k < m; k++){
sum += A[i*m+k] * B[k*t+j];
}
C[i*t+j] += sum + bias[j];
}
}
}
void compute_reduction(float *in, float *out, int n, int m, int gs){
int ng;
if(gs == -1){
ng = 1;
gs = m;
}else{
ng = m/gs;
}
for(int i = 0; i < n; i++){
for(int j0 = 0; j0 < m; j0+=gs){
int j = j0/gs;
out[i*ng+j] = 0;
for(int j1 = j0; j1 < j0+gs; j1++){
out[i*ng+j] += in[i*m+j1];
}
}
}
}
void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
for (int i1 = i0; i1 < i0+gs; i1++){
uint32_t acc = 0;
int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
BQ[i1*m+j] = val;
}
}
}
}
void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
if(bits == 3){
for (int i1 = i0; i1 < i0+gs; i1+=32){
uint32_t acc = 0;
int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
int acc0 = 0, acc1 = 0, acc2 = 0;
acc0 |= temp0;
acc0 |= temp1;
acc0 |= temp2;
acc0 |= temp3;
acc0 |= temp4;
acc0 |= temp5;
acc0 |= temp6;
acc0 |= temp7;
acc0 |= temp8;
acc0 |= temp9;
acc0 |= temp10_0;
acc1 |= temp10_1;
acc1 |= temp11;
acc1 |= temp12;
acc1 |= temp13;
acc1 |= temp14;
acc1 |= temp15;
acc1 |= temp16;
acc1 |= temp17;
acc1 |= temp18;
acc1 |= temp19;
acc1 |= temp20;
acc1 |= temp21_0;
acc2 |= temp21_1;
acc2 |= temp22;
acc2 |= temp23;
acc2 |= temp24;
acc2 |= temp25;
acc2 |= temp26;
acc2 |= temp27;
acc2 |= temp28;
acc2 |= temp29;
acc2 |= temp30;
acc2 |= temp31;
BQ[(3*i1/32)*m+j] = acc0;
BQ[(3*i1/32+1)*m+j] = acc1;
BQ[(3*i1/32+2)*m+j] = acc2;
}
}else{
for (int i1 = i0; i1 < i0+gs; i1+=packed){
uint32_t acc = 0;
for (int i2 = i1; i2 < i1+packed; i2++){
int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
acc = acc | (temp << (bits*(i2-i1)));
}
BQ[(i1/packed)*m+j] = acc;
}
}
}
}
}
int main(int argc, char *argv[]){
// read n m t bits gs from args
if(argc < 6){std::cout << "Parameters not given\n"; return 0;}
int n = atoi(argv[1]);
int m = atoi(argv[2]);
int t = atoi(argv[3]);
int bits = atoi(argv[4]);
int gs = atoi(argv[5]);
int ng;
if(gs == -1){
ng = 1;
}else{
ng = m/gs;
}
float* A = new float[n*m];
float* AB = new float[n*m];
float* B = new float[m*t];
float* BQS = new float[m*t];
float* scales = new float[t*ng];
float* zeros = new float[t*ng];
int* BQ = new int[m*t/8];
int* BQB = new int[m*t/8];
float* sums = new float[n*ng];
float* bias = new float[t];
float* C = new float[n*t];
float* CB = new float[n*t];
float* C2 = new float[n*t];
srand(1);
for (int i = 0; i < n*m; i++){
A[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t*m; i++){
B[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t; i++){
bias[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < n*t; i++){
C[i] = 0.0;
C2[i] = 0.0;
}
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
oracle_mmadd(A, BQS, bias, C, n, m, t);
pack_input(A,AB);
pack_qw(BQ,BQB);
pack_output(C,CB);
compute_reduction(A,sums,n,m,gs);
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
float norm = 0.0;
for (int i = 0; i < n*t; i++){
norm += (C[i] - C2[i]) * (C[i] - C2[i]);
}
if(norm / (n*t) < 0.0001){
int iter = 30;
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
int num_runs = 15;
std::vector<long int> runs(num_runs);
for(int r = 0; r < num_runs; r++){
auto start = std::chrono::high_resolution_clock::now();
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
auto end = std::chrono::high_resolution_clock::now();
runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
}
std::sort(runs.begin(), runs.end());
float cycles_final = runs[num_runs/2 + 1] / iter;
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}else{
float cycles_final = 1e13; // sentinel value: correctness check failed
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}
return 0;
}
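For reference, a small standalone Python sketch of the per-group affine scheme that quantize_sim and quantize above implement: scale = (max - min) / (2^bits - 1), zero = min, values are truncated into [0, 2^bits - 1], and the reconstruction error is at most one quantization step. The numbers are arbitrary.
bits = 4
group = [0.03, 0.41, 0.67, 0.99, 0.12]              # one quantization group
lo, hi = min(group), max(group)
scale = (hi - lo) / (2 ** bits - 1)
zero = lo
for x in group:
    q = int((x - zero) / scale)                     # truncation, as in the C++ above
    x_hat = q * scale + zero
    assert 0 <= q <= 2 ** bits - 1
    assert abs(x - x_hat) <= scale * 1.000001       # error bounded by one step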

View file

@ -0,0 +1,85 @@
def includes():
out = " \
#include <torch/all.h>\n \
#include <torch/python.h>\n \
#include <omp.h>\n \
#include <cmath>\n \
#include <immintrin.h>\n \
\n \
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
"
return out
def module(bits_list=[4, 2]):
out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
for bits in bits_list:
out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
out += 'm.def("unquantize_sim", &unquantize_sim);\n'
# if oracle:
# out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
out += '}\n'
return out
def quant_scalar():
out = " \
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
//find scales and zeros arrays \n \
//quantize \n \
int pack = 32/bits;\n \
for (int j = 0; j < m; j++){\n \
for (int i = 0; i < n; i+=pack){\n \
uint32_t acc = 0;\n \
for (int ii = i; ii < i+pack; ii++){\n \
float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
int temp = (int)ftemp;\n \
acc = acc | (temp << (bits*(ii-i)));\n \
}\n \
BQ[(i/pack)*m+j] = acc;\n \
//BQ[0] = acc;\n \
}\n \
}\n \
}\n \
\n \
void quant_scalar_cpu(\n \
torch::Tensor in, torch::Tensor out, \n \
torch::Tensor scales, torch::Tensor zeros, int bits\n \
) {\n \
\n \
int N = in.size(0);\n \
int M = in.size(1);\n \
\n \
float* input = in.data_ptr<float>(); \n \
float* s = scales.data_ptr<float>();\n \
float* z = zeros.data_ptr<float>();\n \
int* O = out.data_ptr<int>();\n \
\n \
quantize_scalar(input, O, s, z, N, M, bits);\n \
\n \
}\n"
return out

View file

@ -1,4 +1,9 @@
## <center>News or Update</center>
- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so running and training GPTQ models is now more accessible to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team officially released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, and provided [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel, bringing at least a 1.3x inference speedup for int4 quantized models.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can run auto-gptq's CUDA extension kernels.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized model to train adapters, support LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - support download/upload quantized model from/to 🤗 Hub.

View file

@ -146,7 +146,8 @@ def load_model_tokenizer(
use_safetensors: bool = False,
use_fast_tokenizer: bool = False,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True
inject_fused_mlp: bool = True,
disable_exllama: bool = False
):
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
@ -176,7 +177,8 @@ def load_model_tokenizer(
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=False
warmup_triton=False,
disable_exllama=disable_exllama
)
return model, tokenizer
@ -234,6 +236,7 @@ def main():
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--use_safetensors", action="store_true")
parser.add_argument("--use_fast_tokenizer", action="store_true")
parser.add_argument("--disable_exllama", action="store_true")
parser.add_argument("--no_inject_fused_attention", action="store_true")
parser.add_argument("--no_inject_fused_mlp", action="store_true")
parser.add_argument("--num_samples", type=int, default=10)
@ -275,7 +278,8 @@ def main():
use_safetensors=args.use_safetensors,
use_fast_tokenizer=args.use_fast_tokenizer,
inject_fused_attention=not args.no_inject_fused_attention,
inject_fused_mlp=not args.no_inject_fused_mlp
inject_fused_mlp=not args.no_inject_fused_mlp,
disable_exllama=args.disable_exllama
)
end = time.time()
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")

View file

@ -37,6 +37,7 @@ if __name__ == "__main__":
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
args = parser.parse_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -68,7 +69,8 @@ if __name__ == "__main__":
use_safetensors=args.use_safetensors,
trust_remote_code=args.trust_remote_code,
inject_fused_mlp=False,
inject_fused_attention=False
inject_fused_attention=False,
disable_exllama=args.disable_exllama
)
else:
from transformers import AutoModelForCausalLM

140 setup.py
View file

@ -1,67 +1,93 @@
import os
import platform
import sys
from pathlib import Path
from setuptools import setup, find_packages
from setuptools import setup, Extension, find_packages
import subprocess
import math
import platform
python_min_version = (3, 8, 0)
python_min_version_str = '.'.join(map(str, python_min_version))
if sys.version_info < python_min_version:
print(f"You are using Python {platform.python_version()}. Python >={python_min_version_str} is required.")
sys.exit(-1)
BUILD_CUDA_EXT = int(os.environ.get('BUILD_CUDA_EXT', '1')) == 1
if BUILD_CUDA_EXT:
try:
import torch
except:
print("torch is not installed, please install torch first!")
sys.exit(-1)
CUDA_VERSION = "".join(torch.version.cuda.split("."))
else:
CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", "").split("."))
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
common_setup_kwargs = {
"version": "0.3.2",
"version": "0.5.0.dev0",
"name": "auto_gptq",
"author": "PanQiWei",
"description": "An easy-to-use LLMs quantization package with user-friendly apis, based on GPTQ algorithm.",
"long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
"long_description_content_type": "text/markdown",
"url": "https://github.com/PanQiWei/AutoGPTQ",
"keywords": ["gptq", "quantization", "large-language-models", "pytorch", "transformers"],
"keywords": ["gptq", "quantization", "large-language-models", "transformers"],
"platforms": ["windows", "linux"],
"classifiers": [
"Environment :: GPU :: NVIDIA CUDA :: 11.7",
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"License :: OSI Approved :: MIT License",
"Natural Language :: Chinese (Simplified)",
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
],
"python_requires": f">={python_min_version_str}"
]
}
if CUDA_VERSION:
PYPI_RELEASE = os.environ.get('PYPI_RELEASE', None)
BUILD_CUDA_EXT = int(os.environ.get('BUILD_CUDA_EXT', '1')) == 1
if BUILD_CUDA_EXT:
try:
import torch
except:
print("Building cuda extension requires PyTorch(>=1.13.0) been installed, please install PyTorch first!")
sys.exit(-1)
CUDA_VERSION = None
ROCM_VERSION = os.environ.get('ROCM_VERSION', None)
if ROCM_VERSION and not torch.version.hip:
print(
f"Trying to compile auto-gptq for RoCm, but PyTorch {torch.__version__} "
"is installed without RoCm support."
)
sys.exit(-1)
if not ROCM_VERSION:
default_cuda_version = torch.version.cuda
CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", default_cuda_version).split("."))
if ROCM_VERSION:
common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}"
else:
if not CUDA_VERSION:
print(
f"Trying to compile auto-gptq for CUDA, byt Pytorch {torch.__version__} "
"is installed without CUDA support."
)
sys.exit(-1)
# For the PyPI release, the version is simply x.x.x to comply with PEP 440.
if not PYPI_RELEASE:
common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"
requirements = [
"accelerate>=0.19.0",
"datasets",
"sentencepiece",
"numpy",
"rouge",
"gekko",
"torch>=1.13.0",
"safetensors",
"transformers>=4.31.0",
"peft"
"peft",
"tqdm",
]
extras_require = {
"triton": ["triton>=2.0.0"]
"triton": ["triton==2.0.0"],
"test": ["parameterized"]
}
include_dirs = ["autogptq_cuda"]
@ -69,8 +95,16 @@ include_dirs = ["autogptq_cuda"]
additional_setup_kwargs = dict()
if BUILD_CUDA_EXT:
from torch.utils import cpp_extension
if platform.system() != 'Windows':
p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
if not ROCM_VERSION:
from distutils.sysconfig import get_python_lib
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
print("conda_cuda_include_dir", conda_cuda_include_dir)
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
print(f"appending conda cuda include dir {conda_cuda_include_dir}")
@ -78,19 +112,64 @@ if BUILD_CUDA_EXT:
cpp_extension.CUDAExtension(
"autogptq_cuda_64",
[
"autogptq_cuda/autogptq_cuda_64.cpp",
"autogptq_cuda/autogptq_cuda_kernel_64.cu"
"autogptq_extension/cuda_64/autogptq_cuda_64.cpp",
"autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu"
]
),
cpp_extension.CUDAExtension(
"autogptq_cuda_256",
[
"autogptq_cuda/autogptq_cuda_256.cpp",
"autogptq_cuda/autogptq_cuda_kernel_256.cu"
"autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
"autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
]
)
]
if platform.system() != 'Windows':
extensions.append(
cpp_extension.CppExtension(
"cQIGen",
[
'autogptq_extension/qigen/backend.cpp'
],
extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
)
)
if os.name == "nt":
# On Windows, fix the "LNK2001: unresolved external symbol cublasHgemm" error during compilation
cuda_path = os.environ.get("CUDA_PATH", None)
if cuda_path is None:
raise ValueError("The environment variable CUDA_PATH must be set to the path to the CUDA install when installing from source on Windows systems.")
extra_link_args = ["-L", f"{cuda_path}/lib/x64/cublas.lib"]
else:
extra_link_args = []
extensions.append(
cpp_extension.CUDAExtension(
"exllama_kernels",
[
"autogptq_extension/exllama/exllama_ext.cpp",
"autogptq_extension/exllama/cuda_buffers.cu",
"autogptq_extension/exllama/cuda_func/column_remap.cu",
"autogptq_extension/exllama/cuda_func/q4_matmul.cu",
"autogptq_extension/exllama/cuda_func/q4_matrix.cu"
],
extra_link_args=extra_link_args
)
)
extensions.append(
cpp_extension.CUDAExtension(
"exllamav2_kernels",
[
"autogptq_extension/exllamav2/ext.cpp",
"autogptq_extension/exllamav2/cuda/q_matrix.cu",
"autogptq_extension/exllamav2/cuda/q_gemm.cu",
],
extra_link_args=extra_link_args
)
)
additional_setup_kwargs = {
"ext_modules": extensions,
"cmdclass": {'build_ext': cpp_extension.BuildExtension}
@ -101,5 +180,6 @@ setup(
install_requires=requirements,
extras_require=extras_require,
include_dirs=include_dirs,
python_requires=">=3.8.0",
**common_setup_kwargs
)

594 tests/test_q4.py Normal file
View file

@ -0,0 +1,594 @@
import unittest
from parameterized import parameterized
import torch
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from exllama_kernels import prepare_buffers, set_tuning_params
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.modeling._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from transformers import AutoTokenizer
def get_diff(a, ref):
eps = 1e-6
return f"Maxdiff: {(a - ref).abs().max()}, Mean relative diff: {((a - ref).abs() / (ref.abs() + eps)).mean()}"
class TestsQ4Exllama(unittest.TestCase):
# reference generated with cuda_old
REFERENCE = torch.Tensor([5.8398, 6.8555, 7.2734, 6.4219, 6.2070, 5.8203, 6.5664, 6.4219, 6.2148,
5.3281, 5.7578, 7.5312, 8.1016, 6.1133, 7.2031, 6.6484, 6.5156, 6.0117,
6.0312, 6.1914, 6.2109, 6.8125, 5.8125, 7.1172, 7.3125, 6.7305, 5.9961,
6.5117, 6.1914, 5.9648, 7.1680, 6.4766, 7.2070, 6.5469, 6.7734, 6.4219,
6.8086, 7.0469, 5.9297, 6.4727, 6.2539, 5.9570, 7.2383, 5.8945, 6.0820,
5.7969, 7.1094, 6.2188, 6.7500, 7.3555, 6.2930, 6.7734, 5.9219, 7.4805,
6.8750, 6.4102, 6.5898, 6.5469, 7.6016, 6.7461, 5.9492, 7.2227, 5.8164,
5.4570, 6.2930, 7.3984, 6.0938, 7.3984, 5.9609, 6.3516, 6.5664, 5.7969,
7.1250, 6.0781, 6.7930, 5.9492, 6.1641, 6.5898, 6.0586, 6.3359, 6.7930,
7.0469, 6.0664, 6.3320, 5.4414, 6.7617, 5.1641, 7.2891, 6.8516, 6.5312,
5.6914, 7.3711, 6.8203, 5.9492, 7.0781, 6.3164, 7.1992, 7.1133, 7.4219,
7.5586, 7.1836, 6.9102, 6.4844, 6.9805, 6.1953, 6.5156, 5.4844, 6.6602,
6.6719, 7.9844, 6.4727, 6.6367, 6.2227, 6.4531, 5.0625, 6.4609, 6.7031,
6.6445, 6.5234, 6.8633, 6.6055, 5.6055, 6.4453, 7.2617, 6.3945, 6.6367,
6.1055, 7.0664, 6.0820, 6.6875, 6.1445, 6.8672, 6.2070, 6.8828, 6.1484,
6.7070, 6.8516, 6.2734, 7.1055, 7.0586, 6.9648, 5.9727, 6.1016, 6.8750,
7.0078, 7.1523, 5.7383, 5.9531, 6.5508, 7.5352, 6.1602, 6.2578, 6.3906,
5.7383, 6.7031, 5.7344, 6.3516, 5.2852, 7.5312, 6.4531, 6.6406, 6.2266,
6.1094, 5.9102, 5.7617, 6.3789, 7.0508, 6.3750, 6.3320, 6.8555, 6.7266,
7.0352, 7.7695, 6.3984, 6.5039, 6.8320, 6.1602, 6.0312, 6.3828, 6.9023,
7.4336, 7.3711, 6.1016, 7.0703, 6.3281, 6.8281, 6.4922, 5.9453, 5.1016,
6.7188, 6.1406, 6.6289, 7.2695, 6.2070, 6.7070, 7.2930, 7.1836, 6.3828,
6.1992, 6.7070, 7.8008, 7.7773, 5.6602, 7.0273, 6.6172, 6.0898, 5.3516,
7.3359, 5.9727, 6.0078, 7.0586, 6.3086, 6.8555, 7.2617, 7.3477, 6.3828,
7.1133, 6.6328, 7.3516, 6.9141, 7.2031, 6.9805, 6.1719, 6.7812, 8.3047,
6.5898, 6.3633, 6.2539, 7.2773, 6.5938, 6.4141, 6.8203, 6.8906, 7.8828,
5.9609, 6.4180, 7.3984, 5.7539, 7.1758, 6.6641, 6.9062, 6.2578, 7.5508,
6.1719, 6.5742, 5.9375, 6.7891, 6.2109, 6.5039, 6.8750, 6.2031, 6.8828,
7.1094, 5.9570, 7.2969, 6.6797, 6.8828, 5.5430, 6.9648, 5.8398, 6.5430,
6.3945, 6.5664, 5.8086, 6.6172, 7.0586, 6.8867, 6.0820, 5.8125, 6.7070,
7.5742, 6.2578, 6.1328, 6.5391, 5.4531, 6.8242, 6.6953, 6.8008, 6.3398,
6.4805, 7.2266, 6.3281, 6.6875, 6.4688, 5.9414, 7.4297, 5.8711, 6.0625,
5.8750, 6.5664, 5.8867, 6.3477, 6.1133, 6.9453, 5.0547, 6.7812, 6.4922,
7.2422, 5.4688, 6.2109, 7.2148, 6.1758, 5.9297, 7.1953, 5.5195, 6.3203,
5.9961, 7.9297, 6.2695, 6.4414, 6.7266, 7.1875, 7.3203, 5.4062, 6.0625,
7.0898, 5.3828, 5.6133, 6.0742, 6.6836, 5.7109, 7.2852, 7.7539, 7.5820,
6.4258, 5.9336, 6.3750, 6.3555, 7.5469, 6.2539, 6.5898, 6.4102, 7.0469,
5.7344, 7.2031, 6.7969, 5.6836, 7.6523, 6.9297, 7.8672, 6.4766, 6.3008,
7.0977, 6.5430, 7.0938, 5.8398, 6.9883, 6.5312, 6.3203, 6.3594, 5.4062,
6.9688, 5.7930, 6.3164, 6.5547, 7.1992, 5.8750, 6.3008, 6.7930, 6.0391,
7.4766, 6.6094, 6.5625, 5.9805, 6.2422, 7.2109, 6.6875, 5.3047, 7.6211,
5.9453, 6.5625, 6.1641, 6.1250, 6.5977, 7.7422, 7.0742, 5.6875, 6.2656,
6.6250, 6.8945, 5.7070, 6.3203, 5.7500, 6.2695, 6.2773, 6.8516, 6.4883,
7.0000, 6.7578, 6.1875, 5.9844, 5.5703, 6.7188, 5.5273, 5.3438, 7.2500,
6.7852, 6.5195, 6.8125, 6.0664, 6.7852, 7.0000, 7.0781, 6.8477, 7.2930,
6.3438, 7.1523, 6.3281, 6.8047, 7.3203, 5.3359, 6.1484, 6.5586, 7.3828,
6.2344, 7.1523, 6.4102, 5.5898, 7.0195, 7.1172, 5.8008, 6.5742, 6.2891,
8.0312, 6.9023, 6.5898, 7.1953, 6.7266, 6.0078, 5.5430, 6.4766, 6.4258,
5.9648, 8.0859, 5.0547, 7.2188, 7.4375, 6.5156, 5.9922, 6.3281, 6.2852,
6.7734, 6.2461, 6.9805, 5.4648, 5.8867, 6.8242, 6.3008, 6.3281, 7.3047,
7.1836, 6.5195, 6.6328, 6.7188, 5.4336, 6.5078, 5.3477, 5.5508, 7.3125,
5.8750, 6.5195, 6.2383, 6.3594, 6.0898, 6.4141, 5.9844, 6.6250, 7.7109,
6.0391, 7.2344, 5.9453, 5.9453, 7.0586, 5.6641, 7.2773, 6.5195, 7.2227,
6.3359, 5.3203, 6.4375, 7.2383, 6.4023, 6.2148, 7.3750, 5.8164, 6.2109,
6.5430, 5.8164, 6.1680, 6.7656, 6.0820, 6.1094, 6.5312, 6.8906, 6.8320,
6.1289, 6.3125, 7.6797, 6.3008, 6.0000, 7.3320, 6.7852, 6.9297, 6.6328,
6.2266, 5.1602, 6.2031, 7.0547, 5.9492, 6.0703, 6.0977, 6.8086, 6.0742,
6.0195, 7.0625, 6.5781, 5.7461, 6.1562, 7.0430, 6.7148, 6.5312, 6.5820,
6.4570, 7.5508, 5.6289, 6.0547, 6.5000, 7.3125, 5.8477, 5.9297, 6.2578,
6.0078, 5.9922, 7.3398, 7.4922, 7.8906, 7.5547, 5.4648, 6.5156, 6.3242,
6.1094, 6.9219, 6.7227, 6.6836, 7.4023, 5.9648, 7.2383, 6.7695, 6.6797,
7.0547, 6.3047, 6.4688, 6.9961, 6.0391, 5.9727, 6.8398, 6.7422, 5.7656,
5.4766, 6.7852, 7.0820, 5.3516, 7.6523, 5.1562, 6.6445, 6.1211, 6.2695,
6.0703, 6.3594, 6.4062, 6.3398, 5.7578, 6.5391, 6.2500, 6.5742, 6.5000,
7.5625, 7.0117, 6.5547, 7.1250, 6.4453, 6.6094, 6.1875, 6.4219, 6.6172,
6.4336, 6.5703, 6.1758, 6.4219, 6.6016, 6.7383, 6.7070, 6.1328, 5.5586,
6.6367, 6.3789, 6.2578, 5.5039, 6.6172, 6.4648, 5.8086, 7.2031, 5.8125,
6.3711, 7.6758, 7.1289, 5.8086, 6.3008, 6.2109, 6.1602, 6.1797, 7.2305,
6.7266, 6.2422, 5.6719, 6.7070, 6.9414, 6.8594, 7.4023, 7.2109, 6.0156,
6.6680, 6.6172, 7.1250, 6.6523, 6.9531, 6.7617, 6.4961, 6.9414, 5.7188,
7.6367, 6.5469, 6.2305, 6.4414, 7.4648, 5.9102, 6.2461, 6.1367, 6.8203,
6.5703, 6.8867, 7.0000, 6.7539, 6.1719, 6.5469, 6.2422, 5.4297, 5.7305,
5.1641, 6.1875, 7.0312, 6.6484, 6.0234, 7.4102, 6.8711, 6.3086, 6.3711,
6.7344, 6.6992, 5.9766, 7.3906, 7.1875, 6.4883, 6.3984, 7.3438, 6.9688,
6.9062, 6.4375, 6.7891, 7.0117, 6.4883, 5.7500, 7.0898, 7.0742, 6.7070,
5.8750, 6.0469, 6.6445, 5.2773, 6.8984, 6.1641, 7.0508, 7.4609, 5.0273,
6.7734, 6.4531, 5.7656, 6.5312, 7.4648, 6.1250, 6.5625, 7.1367, 6.0625,
6.1211, 6.9766, 6.6758, 6.3164, 6.8828, 6.8203, 6.7500, 6.5352, 7.3008,
6.7852, 6.1914, 5.0508, 6.7188, 7.1172, 6.8008, 6.8086, 5.4883, 6.9180,
6.5742, 6.1719, 7.0469, 7.1523, 5.9492, 5.8594, 6.8320, 6.1719, 6.2031,
6.8398, 7.3008, 6.6289, 6.4922, 6.0000, 5.4766, 6.3320, 6.5117, 6.2812,
7.5742, 6.3516, 7.0039, 6.4570, 7.1523, 7.6289, 6.2578, 7.1875, 6.4844,
5.7930, 6.7070, 7.5508, 7.1797, 6.0430, 6.8711, 6.5742, 7.5781, 6.4766,
6.5391, 6.9453, 6.1992, 6.6367, 6.2812, 6.0234, 6.6953, 7.0312, 6.2031,
6.5625, 6.6719, 6.1719, 6.5586, 5.7031, 7.4609, 6.6211, 7.7227, 6.9141,
6.0469, 6.2500, 5.3828, 6.0078, 5.8164, 5.8867, 6.1523, 6.6523, 6.6953,
7.3125, 6.4844, 5.9570, 5.9531, 6.2109, 5.5039, 6.5117, 6.8203, 6.6133,
6.4766, 5.9297, 7.1445, 7.1914, 6.0117, 6.8281, 6.7422, 6.1328, 6.9805,
6.5625, 6.9180, 7.1133, 7.3359, 5.7617, 5.8711, 6.4961, 6.5859, 6.2422,
6.5273, 6.7461, 6.6992, 6.7695, 6.6289, 5.9453, 5.9805, 7.1172, 6.6719,
6.0039, 7.6875, 6.7812, 7.8359, 6.9531, 7.4336, 7.6602, 6.8164, 7.3945,
7.1602, 6.8789, 5.0078, 6.0547, 6.8086, 6.7070, 6.4688, 6.4492, 6.6172,
5.5625, 6.6914, 6.4297, 5.7461, 5.3359, 6.8750, 6.4609, 7.4062, 5.2070,
6.0820, 6.7383, 6.5703, 6.1797, 6.7070, 6.5977, 5.9961, 6.6328, 6.9375,
6.3906, 6.6484, 4.9609, 6.6445, 6.5898, 7.1875, 7.5195, 6.7969, 6.1367,
6.8906, 7.4297, 6.3633, 6.0508, 6.5000, 6.4648, 6.7539, 6.7109, 5.8086,
6.6016, 7.1133, 4.8672, 6.6367, 6.1641, 5.1758, 6.9453, 6.3242, 7.0664,
6.4805, 6.3516, 6.7383, 8.4688, 6.7305, 5.9844, 6.5938, 7.2969, 6.5977,
7.5898, 6.2969, 6.8672, 6.6680, 7.1289, 6.6875, 5.4258, 8.1875, 8.0391,
7.7969, 6.6445, 7.0703, 7.3359, 6.9805, 6.6328, 6.5352, 6.2422, 5.5820,
6.8633, 6.8047, 6.5703, 6.0117, 6.7539, 7.1719, 6.8438, 7.3633, 6.6016,
7.2070, 6.4727, 5.8008, 7.4062, 7.4805, 6.6445, 5.9023, 6.3984, 6.9961,
6.6680, 6.8242, 6.7148, 6.6172, 6.9727, 6.8320, 5.9766, 6.6133, 5.5977,
6.7773, 7.3906, 6.9219, 7.0781, 6.6914, 5.7539, 6.7969, 6.8008, 5.8047,
7.1055, 6.4961, 6.0352, 5.6211, 7.4414, 7.0703, 6.1172, 6.7461, 6.4492,
7.7148, 6.4258, 6.0039, 6.5156, 7.2188, 7.4531, 7.4844, 7.5938, 7.4023,
6.7617, 6.0078, 6.3320, 5.8906, 7.5977, 5.6523, 6.7734, 6.3008, 5.2227,
7.1719, 7.1289, 6.6602, 5.4609, 7.0312, 6.0820, 6.1719, 6.0000, 6.5547,
6.6328, 7.0547, 7.0859, 6.2656, 5.5234, 6.0273, 6.7891, 7.1875, 6.9531,
6.8203, 6.3516, 6.1172, 6.4648, 6.9180, 7.3906, 6.2812, 5.7109, 6.1484,
6.9102, 6.8711, 7.0156, 6.1445, 5.8867, 6.3828, 5.9961, 6.6914, 6.7891,
7.0820, 6.6719, 6.9297, 6.3750, 6.7578, 6.4883, 6.2227, 6.2305, 6.0508,
6.6484, 5.7578, 7.2070, 7.2383, 6.9375, 7.2578, 6.5312, 6.0312, 6.7930,
6.2578, 7.0625, 7.2148, 6.4961, 7.0703, 6.4727, 7.3906]).to(torch.float16)
def test_exllama(self):
group_size = 128
m = 1
k = 1024
n = 1024
device = torch.device("cuda:0")
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=False, disable_exllamav2=True)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=k,
outfeatures=n,
bias=False,
)
self.assertTrue(isinstance(linear, QuantLinear))
torch.manual_seed(42)
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear = linear.eval()
linear = linear.to(device)
linear = autogptq_post_init(linear, use_act_order=False)
max_inner_outer_dim = max(k, n)
max_dq_buffer_size = linear.infeatures * linear.outfeatures
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
}
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
with torch.no_grad():
res = linear(inp)[0][0]
reference = self.REFERENCE.to(device)
self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))
def test_exllama_buffer_size(self):
prompt = "I am in Paris and" * 450
device = torch.device("cuda:0")
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
self.assertTrue(inp["input_ids"].shape[1] > EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) # 2048 is the default max_input_length
with self.assertRaises(RuntimeError) as cm:
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))
model_q = exllama_set_max_input_length(model_q, 4096)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
model_q = exllama_set_max_input_length(model_q, 1034)
with self.assertRaises(RuntimeError) as cm:
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))
def test_generation_no_act_order(self):
prompt = "I am in Paris and"
device = torch.device("cuda:0")
# Reference generated with the cuda-old kernel
reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
model_basename = "model"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
predicted_text = tokenizer.decode(res[0])
self.assertEqual(predicted_text, reference_output)
def test_generation_with_act_order(self):
prompt = "I am in Paris and"
device = torch.device("cuda:0")
# Reference generated with the cuda-old kernel
reference_output = "<s> I am in Paris and it is a beautiful day. I am sitting in a café, drinking coffee and writing this book. I am surrounded by the sights and sounds of the city, and I am filled with a sense of contentment and gratitude.\n\nI am grateful for the opportunity to live and"
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
predicted_text = tokenizer.decode(res[0])
self.assertEqual(predicted_text, reference_output)
def test_multigpu(self):
# TODO
pass
class TestsQ4CUDA(unittest.TestCase):
REFERENCE_OLD_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.8008, 1.9688, 1.3262, 1.7627, 1.8164, 1.9307,
1.8574, 1.5449, 1.5293, 1.6074, 1.5566, 1.8545, 1.6582, 1.8838, 2.0215,
1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
1.1826, 1.8174, 2.1191, 1.6641, 2.0586, 1.6182, 1.7627, 1.7920, 1.4424,
2.0723, 1.6865, 1.2979, 2.0840, 1.6729, 1.9648, 2.1602, 1.6006, 1.2773,
2.2129, 1.8057, 1.7285, 1.6621, 1.6475, 1.4805, 1.7959, 1.5010, 0.8643,
2.6680, 2.0918, 1.8555, 1.9795, 1.3271, 1.8359, 1.6338, 1.9766, 1.7881,
1.6025, 1.7637, 1.7012, 1.7852, 1.5674, 0.8091, 1.7188, 1.6123, 1.8525,
1.4434, 1.9590, 1.5801, 1.4209, 1.7178, 1.8408, 2.4141, 1.9658, 1.4922,
2.1992, 1.9473, 1.8047, 1.2979, 1.6396, 1.6221, 1.5020, 1.9941, 1.7725,
1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7197, 1.7686,
1.4160, 1.7275, 2.1738, 1.9609, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
2.1113, 1.7227, 1.5811, 1.7607, 2.2773, 1.8945, 1.4111, 1.5801, 1.7744,
2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7803, 1.7744, 1.5312,
1.2695, 1.9209, 2.0469, 1.6777, 2.5215, 1.8389, 1.7598, 1.5498, 1.6807,
1.7324, 1.5938, 1.9268, 1.7734, 1.4463, 2.0391, 2.0527, 2.2129, 1.6787,
2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7080, 1.1611, 1.8584,
2.4570, 1.6016, 1.4834, 1.1777, 1.7969, 1.8955, 1.8906, 1.6738, 1.7510,
1.4316, 1.8340, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
1.8887, 1.5137, 1.4648, 1.6406, 1.7188, 2.2656, 1.5801, 2.1484, 2.0625,
2.0098, 1.7549, 1.1768, 1.4385, 2.0723, 1.6172, 1.7832, 1.8301, 1.6064,
1.5215, 1.9297, 2.3750, 2.1504, 1.7070, 1.1289, 1.4473, 1.5674, 1.6836,
2.2930, 1.1221, 1.5557, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
1.6348, 2.3379, 2.4414, 1.8857, 2.0039, 1.4844, 1.5488, 1.6514, 2.3711,
1.9941, 2.3066, 1.4287, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
2.1309, 1.2666, 1.1523, 1.9561, 1.8584, 1.9746, 1.5986, 1.9688, 2.1973,
1.1523, 2.3281, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)
REFERENCE_OLD_NO_HALF = torch.Tensor([1.5332, 2.1250, 1.7910, 1.7998, 1.9678, 1.3262, 1.7617, 1.8154, 1.9307,
1.8574, 1.5449, 1.5293, 1.6074, 1.5557, 1.8545, 1.6582, 1.8838, 2.0195,
1.8525, 1.2920, 1.9561, 2.2617, 1.7891, 2.2656, 1.6543, 2.0566, 1.4756,
1.1826, 1.8164, 2.1191, 1.6641, 2.0586, 1.6182, 1.7617, 1.7920, 1.4424,
2.0723, 1.6865, 1.2969, 2.0840, 1.6729, 1.9639, 2.1602, 1.5996, 1.2773,
2.2129, 1.8057, 1.7275, 1.6621, 1.6475, 1.4805, 1.7949, 1.5010, 0.8643,
2.6680, 2.0918, 1.8545, 1.9795, 1.3271, 1.8350, 1.6338, 1.9766, 1.7881,
1.6025, 1.7637, 1.7012, 1.7842, 1.5664, 0.8086, 1.7188, 1.6113, 1.8516,
1.4434, 1.9590, 1.5801, 1.4209, 1.7168, 1.8408, 2.4141, 1.9658, 1.4922,
2.1973, 1.9463, 1.8047, 1.2979, 1.6396, 1.6221, 1.5010, 1.9941, 1.7725,
1.6064, 1.5449, 1.8418, 1.2656, 1.4824, 1.7734, 2.0098, 1.7188, 1.7686,
1.4160, 1.7266, 2.1738, 1.9600, 1.7686, 1.6396, 2.1465, 1.2188, 1.2002,
2.1113, 1.7227, 1.5811, 1.7598, 2.2773, 1.8936, 1.4102, 1.5801, 1.7734,
2.0684, 2.1621, 1.8027, 1.1045, 1.9648, 2.2402, 2.0742, 1.3330, 1.5840,
2.1465, 2.0176, 1.5068, 1.9834, 1.7725, 1.5527, 1.7793, 1.7744, 1.5312,
1.2695, 1.9209, 2.0469, 1.6777, 2.5195, 1.8389, 1.7598, 1.5498, 1.6797,
1.7324, 1.5928, 1.9258, 1.7734, 1.4463, 2.0391, 2.0508, 2.2129, 1.6787,
2.0586, 1.8975, 1.5713, 1.6992, 1.8770, 1.7207, 1.7070, 1.1602, 1.8584,
2.4570, 1.6016, 1.4834, 1.1777, 1.7959, 1.8955, 1.8906, 1.6738, 1.7510,
1.4316, 1.8330, 2.2461, 1.7744, 2.1934, 1.4824, 1.8828, 1.6387, 2.4629,
1.8887, 1.5137, 1.4648, 1.6406, 1.7178, 2.2637, 1.5801, 2.1484, 2.0605,
2.0098, 1.7539, 1.1768, 1.4375, 2.0723, 1.6162, 1.7832, 1.8291, 1.6064,
1.5215, 1.9297, 2.3750, 2.1504, 1.7061, 1.1289, 1.4473, 1.5674, 1.6836,
2.2930, 1.1221, 1.5547, 1.7559, 1.8281, 2.0703, 1.9443, 2.0684, 2.2988,
1.6348, 2.3379, 2.4414, 1.8857, 2.0020, 1.4834, 1.5488, 1.6514, 2.3711,
1.9941, 2.3047, 1.4277, 2.1777, 1.6445, 1.6025, 1.5938, 1.5508, 1.9502,
2.1309, 1.2666, 1.1514, 1.9551, 1.8584, 1.9746, 1.5986, 1.9688, 2.1953,
1.1514, 2.3262, 1.2451, 1.8447, 2.2051, 1.5254, 1.5342, 2.1016, 1.6523,
1.6279, 1.1680, 1.3037, 2.1035]).to(torch.float16)

    @parameterized.expand([False, True])
def test_cuda_old(self, use_half2: bool):
group_size = 128
# test the 256 kernel (in_features % 256 == 0 and out_features % 256 == 0)
m = 1
k = 256
n = 256
device = "cuda"
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllamav2=True)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=k,
outfeatures=n,
bias=False,
)
torch.manual_seed(42)
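        # Deterministic random int32 values for the packed weights and slightly offset scales give a
        # fixed, non-trivial layer to compare against the stored reference outputs.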
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear.use_cuda_fp16 = use_half2
self.assertTrue(linear.autogptq_cuda_available)
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
linear = linear.eval()
linear = linear.to(device)
with torch.no_grad():
res = linear(inp)[0][0]
if use_half2:
reference = self.REFERENCE_OLD_HALF.to(device)
else:
reference = self.REFERENCE_OLD_NO_HALF.to(device)
self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))


class TestsQ4ExllamaV2(unittest.TestCase):
    # Reference generated with the cuda-old kernel
REFERENCE = torch.Tensor([5.8398, 6.8555, 7.2734, 6.4219, 6.2070, 5.8203, 6.5664, 6.4219, 6.2148,
5.3281, 5.7578, 7.5312, 8.1016, 6.1133, 7.2031, 6.6484, 6.5156, 6.0117,
6.0312, 6.1914, 6.2109, 6.8125, 5.8125, 7.1172, 7.3125, 6.7305, 5.9961,
6.5117, 6.1914, 5.9648, 7.1680, 6.4766, 7.2070, 6.5469, 6.7734, 6.4219,
6.8086, 7.0469, 5.9297, 6.4727, 6.2539, 5.9570, 7.2383, 5.8945, 6.0820,
5.7969, 7.1094, 6.2188, 6.7500, 7.3555, 6.2930, 6.7734, 5.9219, 7.4805,
6.8750, 6.4102, 6.5898, 6.5469, 7.6016, 6.7461, 5.9492, 7.2227, 5.8164,
5.4570, 6.2930, 7.3984, 6.0938, 7.3984, 5.9609, 6.3516, 6.5664, 5.7969,
7.1250, 6.0781, 6.7930, 5.9492, 6.1641, 6.5898, 6.0586, 6.3359, 6.7930,
7.0469, 6.0664, 6.3320, 5.4414, 6.7617, 5.1641, 7.2891, 6.8516, 6.5312,
5.6914, 7.3711, 6.8203, 5.9492, 7.0781, 6.3164, 7.1992, 7.1133, 7.4219,
7.5586, 7.1836, 6.9102, 6.4844, 6.9805, 6.1953, 6.5156, 5.4844, 6.6602,
6.6719, 7.9844, 6.4727, 6.6367, 6.2227, 6.4531, 5.0625, 6.4609, 6.7031,
6.6445, 6.5234, 6.8633, 6.6055, 5.6055, 6.4453, 7.2617, 6.3945, 6.6367,
6.1055, 7.0664, 6.0820, 6.6875, 6.1445, 6.8672, 6.2070, 6.8828, 6.1484,
6.7070, 6.8516, 6.2734, 7.1055, 7.0586, 6.9648, 5.9727, 6.1016, 6.8750,
7.0078, 7.1523, 5.7383, 5.9531, 6.5508, 7.5352, 6.1602, 6.2578, 6.3906,
5.7383, 6.7031, 5.7344, 6.3516, 5.2852, 7.5312, 6.4531, 6.6406, 6.2266,
6.1094, 5.9102, 5.7617, 6.3789, 7.0508, 6.3750, 6.3320, 6.8555, 6.7266,
7.0352, 7.7695, 6.3984, 6.5039, 6.8320, 6.1602, 6.0312, 6.3828, 6.9023,
7.4336, 7.3711, 6.1016, 7.0703, 6.3281, 6.8281, 6.4922, 5.9453, 5.1016,
6.7188, 6.1406, 6.6289, 7.2695, 6.2070, 6.7070, 7.2930, 7.1836, 6.3828,
6.1992, 6.7070, 7.8008, 7.7773, 5.6602, 7.0273, 6.6172, 6.0898, 5.3516,
7.3359, 5.9727, 6.0078, 7.0586, 6.3086, 6.8555, 7.2617, 7.3477, 6.3828,
7.1133, 6.6328, 7.3516, 6.9141, 7.2031, 6.9805, 6.1719, 6.7812, 8.3047,
6.5898, 6.3633, 6.2539, 7.2773, 6.5938, 6.4141, 6.8203, 6.8906, 7.8828,
5.9609, 6.4180, 7.3984, 5.7539, 7.1758, 6.6641, 6.9062, 6.2578, 7.5508,
6.1719, 6.5742, 5.9375, 6.7891, 6.2109, 6.5039, 6.8750, 6.2031, 6.8828,
7.1094, 5.9570, 7.2969, 6.6797, 6.8828, 5.5430, 6.9648, 5.8398, 6.5430,
6.3945, 6.5664, 5.8086, 6.6172, 7.0586, 6.8867, 6.0820, 5.8125, 6.7070,
7.5742, 6.2578, 6.1328, 6.5391, 5.4531, 6.8242, 6.6953, 6.8008, 6.3398,
6.4805, 7.2266, 6.3281, 6.6875, 6.4688, 5.9414, 7.4297, 5.8711, 6.0625,
5.8750, 6.5664, 5.8867, 6.3477, 6.1133, 6.9453, 5.0547, 6.7812, 6.4922,
7.2422, 5.4688, 6.2109, 7.2148, 6.1758, 5.9297, 7.1953, 5.5195, 6.3203,
5.9961, 7.9297, 6.2695, 6.4414, 6.7266, 7.1875, 7.3203, 5.4062, 6.0625,
7.0898, 5.3828, 5.6133, 6.0742, 6.6836, 5.7109, 7.2852, 7.7539, 7.5820,
6.4258, 5.9336, 6.3750, 6.3555, 7.5469, 6.2539, 6.5898, 6.4102, 7.0469,
5.7344, 7.2031, 6.7969, 5.6836, 7.6523, 6.9297, 7.8672, 6.4766, 6.3008,
7.0977, 6.5430, 7.0938, 5.8398, 6.9883, 6.5312, 6.3203, 6.3594, 5.4062,
6.9688, 5.7930, 6.3164, 6.5547, 7.1992, 5.8750, 6.3008, 6.7930, 6.0391,
7.4766, 6.6094, 6.5625, 5.9805, 6.2422, 7.2109, 6.6875, 5.3047, 7.6211,
5.9453, 6.5625, 6.1641, 6.1250, 6.5977, 7.7422, 7.0742, 5.6875, 6.2656,
6.6250, 6.8945, 5.7070, 6.3203, 5.7500, 6.2695, 6.2773, 6.8516, 6.4883,
7.0000, 6.7578, 6.1875, 5.9844, 5.5703, 6.7188, 5.5273, 5.3438, 7.2500,
6.7852, 6.5195, 6.8125, 6.0664, 6.7852, 7.0000, 7.0781, 6.8477, 7.2930,
6.3438, 7.1523, 6.3281, 6.8047, 7.3203, 5.3359, 6.1484, 6.5586, 7.3828,
6.2344, 7.1523, 6.4102, 5.5898, 7.0195, 7.1172, 5.8008, 6.5742, 6.2891,
8.0312, 6.9023, 6.5898, 7.1953, 6.7266, 6.0078, 5.5430, 6.4766, 6.4258,
5.9648, 8.0859, 5.0547, 7.2188, 7.4375, 6.5156, 5.9922, 6.3281, 6.2852,
6.7734, 6.2461, 6.9805, 5.4648, 5.8867, 6.8242, 6.3008, 6.3281, 7.3047,
7.1836, 6.5195, 6.6328, 6.7188, 5.4336, 6.5078, 5.3477, 5.5508, 7.3125,
5.8750, 6.5195, 6.2383, 6.3594, 6.0898, 6.4141, 5.9844, 6.6250, 7.7109,
6.0391, 7.2344, 5.9453, 5.9453, 7.0586, 5.6641, 7.2773, 6.5195, 7.2227,
6.3359, 5.3203, 6.4375, 7.2383, 6.4023, 6.2148, 7.3750, 5.8164, 6.2109,
6.5430, 5.8164, 6.1680, 6.7656, 6.0820, 6.1094, 6.5312, 6.8906, 6.8320,
6.1289, 6.3125, 7.6797, 6.3008, 6.0000, 7.3320, 6.7852, 6.9297, 6.6328,
6.2266, 5.1602, 6.2031, 7.0547, 5.9492, 6.0703, 6.0977, 6.8086, 6.0742,
6.0195, 7.0625, 6.5781, 5.7461, 6.1562, 7.0430, 6.7148, 6.5312, 6.5820,
6.4570, 7.5508, 5.6289, 6.0547, 6.5000, 7.3125, 5.8477, 5.9297, 6.2578,
6.0078, 5.9922, 7.3398, 7.4922, 7.8906, 7.5547, 5.4648, 6.5156, 6.3242,
6.1094, 6.9219, 6.7227, 6.6836, 7.4023, 5.9648, 7.2383, 6.7695, 6.6797,
7.0547, 6.3047, 6.4688, 6.9961, 6.0391, 5.9727, 6.8398, 6.7422, 5.7656,
5.4766, 6.7852, 7.0820, 5.3516, 7.6523, 5.1562, 6.6445, 6.1211, 6.2695,
6.0703, 6.3594, 6.4062, 6.3398, 5.7578, 6.5391, 6.2500, 6.5742, 6.5000,
7.5625, 7.0117, 6.5547, 7.1250, 6.4453, 6.6094, 6.1875, 6.4219, 6.6172,
6.4336, 6.5703, 6.1758, 6.4219, 6.6016, 6.7383, 6.7070, 6.1328, 5.5586,
6.6367, 6.3789, 6.2578, 5.5039, 6.6172, 6.4648, 5.8086, 7.2031, 5.8125,
6.3711, 7.6758, 7.1289, 5.8086, 6.3008, 6.2109, 6.1602, 6.1797, 7.2305,
6.7266, 6.2422, 5.6719, 6.7070, 6.9414, 6.8594, 7.4023, 7.2109, 6.0156,
6.6680, 6.6172, 7.1250, 6.6523, 6.9531, 6.7617, 6.4961, 6.9414, 5.7188,
7.6367, 6.5469, 6.2305, 6.4414, 7.4648, 5.9102, 6.2461, 6.1367, 6.8203,
6.5703, 6.8867, 7.0000, 6.7539, 6.1719, 6.5469, 6.2422, 5.4297, 5.7305,
5.1641, 6.1875, 7.0312, 6.6484, 6.0234, 7.4102, 6.8711, 6.3086, 6.3711,
6.7344, 6.6992, 5.9766, 7.3906, 7.1875, 6.4883, 6.3984, 7.3438, 6.9688,
6.9062, 6.4375, 6.7891, 7.0117, 6.4883, 5.7500, 7.0898, 7.0742, 6.7070,
5.8750, 6.0469, 6.6445, 5.2773, 6.8984, 6.1641, 7.0508, 7.4609, 5.0273,
6.7734, 6.4531, 5.7656, 6.5312, 7.4648, 6.1250, 6.5625, 7.1367, 6.0625,
6.1211, 6.9766, 6.6758, 6.3164, 6.8828, 6.8203, 6.7500, 6.5352, 7.3008,
6.7852, 6.1914, 5.0508, 6.7188, 7.1172, 6.8008, 6.8086, 5.4883, 6.9180,
6.5742, 6.1719, 7.0469, 7.1523, 5.9492, 5.8594, 6.8320, 6.1719, 6.2031,
6.8398, 7.3008, 6.6289, 6.4922, 6.0000, 5.4766, 6.3320, 6.5117, 6.2812,
7.5742, 6.3516, 7.0039, 6.4570, 7.1523, 7.6289, 6.2578, 7.1875, 6.4844,
5.7930, 6.7070, 7.5508, 7.1797, 6.0430, 6.8711, 6.5742, 7.5781, 6.4766,
6.5391, 6.9453, 6.1992, 6.6367, 6.2812, 6.0234, 6.6953, 7.0312, 6.2031,
6.5625, 6.6719, 6.1719, 6.5586, 5.7031, 7.4609, 6.6211, 7.7227, 6.9141,
6.0469, 6.2500, 5.3828, 6.0078, 5.8164, 5.8867, 6.1523, 6.6523, 6.6953,
7.3125, 6.4844, 5.9570, 5.9531, 6.2109, 5.5039, 6.5117, 6.8203, 6.6133,
6.4766, 5.9297, 7.1445, 7.1914, 6.0117, 6.8281, 6.7422, 6.1328, 6.9805,
6.5625, 6.9180, 7.1133, 7.3359, 5.7617, 5.8711, 6.4961, 6.5859, 6.2422,
6.5273, 6.7461, 6.6992, 6.7695, 6.6289, 5.9453, 5.9805, 7.1172, 6.6719,
6.0039, 7.6875, 6.7812, 7.8359, 6.9531, 7.4336, 7.6602, 6.8164, 7.3945,
7.1602, 6.8789, 5.0078, 6.0547, 6.8086, 6.7070, 6.4688, 6.4492, 6.6172,
5.5625, 6.6914, 6.4297, 5.7461, 5.3359, 6.8750, 6.4609, 7.4062, 5.2070,
6.0820, 6.7383, 6.5703, 6.1797, 6.7070, 6.5977, 5.9961, 6.6328, 6.9375,
6.3906, 6.6484, 4.9609, 6.6445, 6.5898, 7.1875, 7.5195, 6.7969, 6.1367,
6.8906, 7.4297, 6.3633, 6.0508, 6.5000, 6.4648, 6.7539, 6.7109, 5.8086,
6.6016, 7.1133, 4.8672, 6.6367, 6.1641, 5.1758, 6.9453, 6.3242, 7.0664,
6.4805, 6.3516, 6.7383, 8.4688, 6.7305, 5.9844, 6.5938, 7.2969, 6.5977,
7.5898, 6.2969, 6.8672, 6.6680, 7.1289, 6.6875, 5.4258, 8.1875, 8.0391,
7.7969, 6.6445, 7.0703, 7.3359, 6.9805, 6.6328, 6.5352, 6.2422, 5.5820,
6.8633, 6.8047, 6.5703, 6.0117, 6.7539, 7.1719, 6.8438, 7.3633, 6.6016,
7.2070, 6.4727, 5.8008, 7.4062, 7.4805, 6.6445, 5.9023, 6.3984, 6.9961,
6.6680, 6.8242, 6.7148, 6.6172, 6.9727, 6.8320, 5.9766, 6.6133, 5.5977,
6.7773, 7.3906, 6.9219, 7.0781, 6.6914, 5.7539, 6.7969, 6.8008, 5.8047,
7.1055, 6.4961, 6.0352, 5.6211, 7.4414, 7.0703, 6.1172, 6.7461, 6.4492,
7.7148, 6.4258, 6.0039, 6.5156, 7.2188, 7.4531, 7.4844, 7.5938, 7.4023,
6.7617, 6.0078, 6.3320, 5.8906, 7.5977, 5.6523, 6.7734, 6.3008, 5.2227,
7.1719, 7.1289, 6.6602, 5.4609, 7.0312, 6.0820, 6.1719, 6.0000, 6.5547,
6.6328, 7.0547, 7.0859, 6.2656, 5.5234, 6.0273, 6.7891, 7.1875, 6.9531,
6.8203, 6.3516, 6.1172, 6.4648, 6.9180, 7.3906, 6.2812, 5.7109, 6.1484,
6.9102, 6.8711, 7.0156, 6.1445, 5.8867, 6.3828, 5.9961, 6.6914, 6.7891,
7.0820, 6.6719, 6.9297, 6.3750, 6.7578, 6.4883, 6.2227, 6.2305, 6.0508,
6.6484, 5.7578, 7.2070, 7.2383, 6.9375, 7.2578, 6.5312, 6.0312, 6.7930,
6.2578, 7.0625, 7.2148, 6.4961, 7.0703, 6.4727, 7.3906]).to(torch.float16)

    def test_exllamav2(self):
from auto_gptq.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
group_size = 128
m = 1
k = 1024
n = 1024
device = torch.device("cuda:0")
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=k,
outfeatures=n,
bias=False,
)
        self.assertIsInstance(linear, QuantLinear)
torch.manual_seed(42)
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear = linear.eval()
linear = linear.to(device)
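        # autogptq_post_init prepares the exllamav2 kernel (e.g. its device scratch buffers) before
        # the first forward pass; act-order is disabled for this synthetic layer.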
linear = autogptq_post_init(linear, use_act_order=False)
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
with torch.no_grad():
res = linear(inp)[0][0]
reference = self.REFERENCE.to(device)
self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))

    def test_generation_no_act_order(self):
prompt = "I am in Paris and"
device = torch.device("cuda:0")
# Reference generated with the cuda-old kernel
reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
model_basename = "model"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, model_basename=model_basename)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
predicted_text = tokenizer.decode(res[0])
self.assertEqual(predicted_text, reference_output)

    def test_generation_with_act_order(self):
prompt = "I am in Paris and"
device = torch.device("cuda:0")
# Reference generated with the cuda-old kernel
reference_output = "<s> I am in Paris and it is a beautiful day. I am sitting in a café, drinking coffee and writing this book. I am surrounded by the sights and sounds of the city, and I am filled with a sense of contentment and gratitude.\n\nI am grateful for the opportunity to live and"
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
predicted_text = tokenizer.decode(res[0])
self.assertEqual(predicted_text, reference_output)

    def test_exllama_buffer_size(self):
        # Repeat the prompt so it tokenizes to well over the default 2048-token input length.
        prompt = "I'm in Paris and" * 1000
device = torch.device("cuda:0")
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
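        # Generation on the >2048-token prompt is expected to complete here without an explicit
        # exllama_set_max_input_length call.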
self.assertTrue(inp["input_ids"].shape[1] > 2048) # 2048 is the default max_input_length for LLama
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)