Compare commits
No commits in common. "main" and "v0.1.0" have entirely different histories.
108 changed files with 1067 additions and 17780 deletions
33  .github/ISSUE_TEMPLATE/bug_report.md  (vendored)
@@ -1,33 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: bug
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Hardware details**
Information about CPU and GPU, such as RAM, number, etc.

**Software version**
Version of relevant software such as operation system, cuda toolkit, python, auto-gptq, pytorch, transformers, accelerate, etc.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Additional context**
Add any other context about the problem here.
10  .github/ISSUE_TEMPLATE/custom.md  (vendored)
@@ -1,10 +0,0 @@
---
name: Custom issue template
about: Describe this issue template's purpose here.
title: ''
labels: ''
assignees: ''

---

20  .github/ISSUE_TEMPLATE/feature_request.md  (vendored)
@@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[FEATURE]"
labels: enhancement
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
69  .github/workflows/build_wheels_cuda.yml  (vendored)
@@ -1,69 +0,0 @@
name: Build AutoGPTQ Wheels with CUDA

on: workflow_dispatch

jobs:
  build_wheels:
    if: ${{ github.repository_owner == 'PanQiWei' }}
    name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA ${{ matrix.cuda }}
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-20.04, windows-latest]
        pyver: ["3.8", "3.9", "3.10", "3.11"]
        cuda: ["11.7", "11.8"]
    defaults:
      run:
        shell: pwsh
    env:
      CUDA_VERSION: ${{ matrix.cuda }}

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.pyver }}

      - name: Setup Miniconda
        uses: conda-incubator/setup-miniconda@v2.2.0
        with:
          activate-environment: "build"
          python-version: ${{ matrix.pyver }}
          mamba-version: "*"
          use-mamba: false
          channels: conda-forge,defaults
          channel-priority: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Install Dependencies
        run: |
          conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
          conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
          python -m pip install --upgrade build setuptools wheel ninja

      - name: Build Wheel
        run: |
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX
          if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}

          # TODO: remove this
          if (!$IsLinux) {$env:INCLUDE_EXLLAMA_KERNELS = 0}

          $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
          if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }
          python setup.py sdist bdist_wheel

      - uses: actions/upload-artifact@v3
        if: runner.os == 'Linux'
        with:
          name: 'linux-cuda-wheels'
          path: ./dist/*.whl

      - uses: actions/upload-artifact@v3
        if: runner.os == 'Windows'
        with:
          name: 'windows-cuda-wheels'
          path: ./dist/*.whl
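For readers who prefer Python to PowerShell, the following is an illustrative restatement of the `TORCH_CUDA_ARCH_LIST` selection logic in the workflow above; it is a sketch of the same rule, not part of the repository.

```python
# Illustrative only: mirrors the arch-list choice made in PowerShell in the workflow above.
def torch_cuda_arch_list(cuda_version: str) -> str:
    base = "6.0 6.1 7.0 7.5 8.0 8.6"
    # CUDA >= 11.8 additionally targets Ada (8.9) and Hopper (9.0).
    return f"{base} 8.9 9.0+PTX" if float(cuda_version) >= 11.8 else f"{base}+PTX"

print(torch_cuda_arch_list("11.7"))  # 6.0 6.1 7.0 7.5 8.0 8.6+PTX
print(torch_cuda_arch_list("11.8"))  # 6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX
```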
74  .github/workflows/build_wheels_pypi.yml  (vendored)
@@ -1,74 +0,0 @@
name: Build AutoGPTQ Wheels for PyPI with CUDA

on: workflow_dispatch

jobs:
  build_wheels:
    if: ${{ github.repository_owner == 'PanQiWei' }}
    name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA 11.7
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-20.04, windows-latest]
        pyver: ["3.8", "3.9", "3.10", "3.11"]
    defaults:
      run:
        shell: pwsh
    env:
      CUDA_VERSION: "11.7"

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.pyver }}

      - name: Setup Miniconda
        uses: conda-incubator/setup-miniconda@v2.2.0
        with:
          activate-environment: "build"
          python-version: ${{ matrix.pyver }}
          mamba-version: "*"
          use-mamba: false
          channels: conda-forge,defaults
          channel-priority: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Install Dependencies
        run: |
          conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
          conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
          python -m pip install --upgrade build setuptools wheel ninja

      - name: Build Wheel
        run: |
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX
          if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}

          $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
          if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }

          $env:PYPI_RELEASE = "1"

          echo "CUDA_PATH:"
          echo $env:CUDA_PATH

          echo "PYPI_RELEASE:"
          echo $env:PYPI_RELEASE

          python setup.py sdist bdist_wheel

      - uses: actions/upload-artifact@v3
        if: runner.os == 'Linux'
        with:
          name: 'linux-cuda-wheels'
          path: ./dist/*.whl

      - uses: actions/upload-artifact@v3
        if: runner.os == 'Windows'
        with:
          name: 'windows-cuda-wheels'
          path: ./dist/*.whl
103  .github/workflows/build_wheels_rocm.yml  (vendored)
@@ -1,103 +0,0 @@
name: Build AutoGPTQ Wheels with ROCm

on: workflow_dispatch

jobs:
  build_wheels:
    if: ${{ github.repository_owner == 'PanQiWei' }}

    strategy:
      matrix:
        os: [ubuntu-20.04]
        python: ["3.8", "3.9", "3.10", "3.11"]
        rocm: ["5.4.2"] # , "5.5", "5.6"]

    name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }}
    runs-on: ${{ matrix.os }}

    defaults:
      run:
        shell: bash

    steps:
      - uses: actions/checkout@v3

      - name: Free disk space
        run: |
          df -h
          echo "Removing large packages"
          sudo apt-get remove -y '^dotnet-.*'
          sudo apt-get remove -y 'php.*'
          sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel
          df -h
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get clean
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get autoclean -y >/dev/null 2>&1
          df -h
          echo "https://github.com/actions/virtual-environments/issues/709"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
          df -h
          echo "remove big /usr/local"
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
          df -h
          sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
          sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
          sudo rm -rf /usr/share/swift > /dev/null 2>&1
          df -h

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}

      - name: Setup Miniconda
        uses: conda-incubator/setup-miniconda@v2.2.0
        with:
          activate-environment: "build"
          python-version: ${{ matrix.python }}
          mamba-version: "*"
          use-mamba: false
          channels: conda-forge,defaults
          channel-priority: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Set up environment
        run: |
          echo "Using python:"
          python --version
          which python

          if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
          elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb
          else
            export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb
          fi

          curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
          sudo dpkg -i $ROCM_DL_FILE
          sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y

      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev

          python -m pip install --upgrade build setuptools wheel ninja
          python -m pip install torch --index-url https://download.pytorch.org/whl/rocm${{ matrix.rocm }}

      - name: Build wheels
        run: |
          echo "Using python for build:"
          python --version
          which python

          ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel

      - uses: actions/upload-artifact@v3
        with:
          name: 'linux-rocm-wheels'
          path: ./dist/*.whl
160  .gitignore  (vendored)
@@ -1,160 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
184  README.md
@@ -12,101 +12,84 @@
<p>
<b>English</b> |
<a href="https://github.com/PanQiWei/AutoGPTQ/blob/main/README_zh.md">中文</a>
</p>
<p>
</h4>

## The path to v1.0.0

Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been huge in terms of my career plans. Not long ago, I officially bid farewell to the startup team that I joined for two years after graduation. I'm very grateful to the leaders and colleagues of the team for their trust and guidance, which enabled me to grow rapidly in two years; at the same time, I'm also really grateful to the team for allowing me to use the internal A100 GPU server cluster free of charge since the start of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, it can no longer be used in the future, so **it will mean a lot to me if there is new hardware sponsorship!**) In the past two years, I served as an AI engineer in this team, responsible for the architecture design and development of an LLM-based dialogue system. We successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), which is **an LLM-native AI agent platform, where users can use multiple AI agents to build a highly automated team, allowing them to interact with each other in the workflow and collaborate to complete complex projects efficiently.**

Getting back to the topic, I'm very excited to see that in the past few months, research on optimizing the inference performance of LLMs has made tremendous progress. Now we can not only run LLM inference efficiently on high-end GPUs, but also on CPUs and even edge devices. A series of technological advancements make me eager to make more contributions to the open source community. Therefore, I will first use about four weeks to gradually update AutoGPTQ to the official v1.0.0 release. During this period, 2~3 minor versions will also be released so that users can try performance optimizations and new features in a timely manner. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will be able to serve as an extendable and flexible quantization backend that supports all GPTQ-like methods and automatically quantizes LLMs written in PyTorch**. I detailed the development plan in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/348), feel free to drop in there for discussion and give your suggestions!

## News or Update

- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so running and training GPTQ models is now more accessible to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team officially released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, and provided [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel for at least a 1.3x inference speed-up on int4 quantized models.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to train adapters on GPTQ-quantized models; supports LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support download/upload quantized model from/to 🤗 Hub.
- 2023-05-04 - (Update) - Support using faster cuda kernel when `not desc_act or group_size == -1`.
- 2023-04-29 - (Update) - Support loading quantized model from arbitrary quantize_config and model_basename.
- 2023-04-28 - (Update) - Support CPU offload and quantize/inference on multiple devices, support `gpt2` type models.

*For more history please turn to [here](docs/NEWS_OR_UPDATE.md)*

## Performance Comparison

### Inference Speed
> The results are generated using [this script](examples/benchmark/generation_speed.py); the input batch size is 1, the decoding strategy is beam search, the model is forced to generate 512 tokens, and the speed metric is tokens/s (the larger, the better).
>
> The quantized model is loaded using the setup that yields the fastest inference speed.

| model         | GPU           | num_beams | fp16  | gptq-int4 |
|---------------|---------------|-----------|-------|-----------|
| llama-7b      | 1xA100-40G    | 1         | 18.87 | 25.53     |
| llama-7b      | 1xA100-40G    | 4         | 68.79 | 91.30     |
| moss-moon 16b | 1xA100-40G    | 1         | 12.48 | 15.25     |
| moss-moon 16b | 1xA100-40G    | 4         | OOM   | 42.67     |
| moss-moon 16b | 2xA100-40G    | 1         | 06.83 | 06.78     |
| moss-moon 16b | 2xA100-40G    | 4         | 13.10 | 10.80     |
| gpt-j 6b      | 1xRTX3060-12G | 1         | OOM   | 29.55     |
| gpt-j 6b      | 1xRTX3060-12G | 4         | OOM   | 47.36     |

### Perplexity
For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#result) and [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#gptq-vs-bitsandbytes)
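The repository's own PPL measurement lives in the benchmark script referenced above. As a rough, hedged illustration of what such a measurement does (this is a generic sketch, not the repository's script), computing perplexity boils down to exponentiating the mean next-token loss:

```python
# Generic perplexity sketch, not the repository's benchmark script.
# Assumes `model` is any Hugging Face causal LM (quantized or not) and `tokenizer` its tokenizer.
import torch

def perplexity(model, tokenizer, text, device="cuda:0"):
    enc = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # For causal LMs, passing input_ids as labels yields the mean next-token cross-entropy.
        loss = model(input_ids=enc["input_ids"], labels=enc["input_ids"]).loss
    return torch.exp(loss).item()
```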
## Installation

### Quick Installation
You can install the latest stable release of AutoGPTQ from pip with pre-built wheels compatible with PyTorch 2.0.1:
You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
#### disable cuda extensions
By default, the CUDA extensions are installed when `torch` and CUDA are already present on your machine; if you don't want to use them, install with:
```shell
BUILD_CUDA_EXT=0 pip install auto-gptq
```
And to make sure `quant_cuda` is not ever in your virtual environment, run:
```shell
pip uninstall quant_cuda -y
```
#### to support LLaMa model
If you want to try LLaMa but your `transformers` version does not yet support it, run:
```shell
pip install auto-gptq[llama]
```
#### to support triton speedup
To integrate with `triton`, run:
> warning: currently triton only supports linux; 3-bit quantization is not supported when using triton

* For CUDA 11.7: `pip install auto-gptq`
* For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`

**Warning:** These wheels are not expected to work on PyTorch nightly. Please install AutoGPTQ from source when using PyTorch nightly.
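Before choosing one of the wheel index URLs above, it can help to confirm which CUDA (or ROCm) build of PyTorch is actually installed; a small check:

```python
# Quick check of the installed PyTorch build before picking a wheel index URL.
import torch

print(torch.__version__)                     # e.g. 2.0.1
print(torch.version.cuda)                    # e.g. '11.7' or '11.8'; None for CPU/ROCm builds
print(getattr(torch.version, "hip", None))   # set on ROCm builds, otherwise None
```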
AutoGPTQ can be installed with the Triton dependency with `pip install auto-gptq[triton]` in order to be able to use the Triton backend (currently only supports linux, no 3-bits quantization).
```shell
pip install auto-gptq[triton]
```

### Install from source

Clone the source code:
```shell
git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
```
Then, install from source:
```shell
pip install -v .
pip install .
```
You can set `BUILD_CUDA_EXT=0` to disable pytorch extension building, but this is **strongly discouraged** as AutoGPTQ then falls back on a slow python implementation.
Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch extension building.

To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. The compilation can be sped up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
Use `.[llama]` if you want to try LLaMa model.

```
ROCM_VERSION=5.6 pip install -v .
```
Use `.[triton]` if you want to integrate with triton and it's available on your operating system.

For RoCm systems, the packages `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev` are required to build.

## Quick Tour
## Supported Models
Currently, `auto_gptq` supports: `bloom`, `gpt2`, `gpt_neox`, `gptj`, `llama`, `moss` and `opt`; more Transformer models will come soon!

### Quantization and Inference
> warning: this is just a showcase of the basic APIs in AutoGPTQ, which uses only one sample to quantize a very small model; the quality of a model quantized with so few samples may not be good.
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!

Below is an example of the simplest use of `auto_gptq` to quantize a model and run inference after quantization:
## Usage

**Here are [tutorials](docs/tutorial) (continually updated...) for using `auto-gptq`; it's highly recommended that newcomers read them first before trying the example scripts.**

### Basic
> warning: this is just a showcase of the basic APIs in AutoGPTQ, which uses only one sample to quantize a very small model, thus it may not perform as well as expected on larger LLMs.

Below is an example of the simplest use of `auto_gptq`:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit"


tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [
    tokenizer(
@@ -117,14 +100,13 @@ examples = [
quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # set to False can significantly speed up inference but the perplexity may be slightly worse
)

# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# quantize model, the examples should be a list of dicts whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)
model.quantize(examples, use_triton=False)

# save quantized model
model.save_quantized(quantized_model_dir)
@@ -132,28 +114,11 @@ model.save_quantized(quantized_model_dir)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, Login first via huggingface-cli login.
# or pass explicit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)

# alternatively you can save and push at the same time
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)

# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")

# download quantized model from Hugging Face Hub and load to the first GPU
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)

# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to("cuda:0"))[0]))

# or you can also use pipeline
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
@@ -163,10 +128,7 @@ print(pipeline("auto-gptq is")[0]["generated_text"])
For more advanced features of model quantization, please refer to [this script](examples/quantization/quant_with_alpaca.py)

### Customize Model
<details>

<summary>Below is an example to extend `auto_gptq` to support `OPT` model, as you will see, it's very easy:</summary>

Below is an example to extend `auto_gptq` to support `OPT` model, as you will see, it's very easy:
```python
from auto_gptq.modeling import BaseGPTQForCausalLM

@@ -192,17 +154,12 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
```
After this, you can use `OPTGPTQForCausalLM.from_pretrained` and other methods as shown in Basic.

</details>

### Evaluation on Downstream Tasks
You can use tasks defined in `auto_gptq.eval_tasks` to evaluate the model's performance on a specific downstream task before and after quantization.

The predefined tasks support all causal-language-models implemented in [🤗 transformers](https://github.com/huggingface/transformers) and in this project.

<details>

<summary>Below is an example to evaluate `EleutherAI/gpt-j-6b` on sequence-classification task using `cardiffnlp/tweet_sentiment_multilingual` dataset:</summary>

Below is an example to evaluate `EleutherAI/gpt-j-6b` on sequence-classification task using `cardiffnlp/tweet_sentiment_multilingual` dataset:
```python
from functools import partial

@@ -278,46 +235,9 @@ print(
)
```

</details>

## Learn More
[tutorials](docs/tutorial) provide step-by-step guidance to integrate `auto_gptq` with your own project and some best practice principles.

[examples](examples/README.md) provide plenty of example scripts to use `auto_gptq` in different ways.

## Supported Models

> you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`.
>
> for example, model_type of `WizardLM`, `vicuna` and `gpt4all` are all `llama`, hence they are all supported by `auto_gptq`.

| model type                         | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt                                                                              |
|------------------------------------|--------------|-----------|-----------|---------------|---------------------------------------------------------------------------------------------------|
| bloom                              | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
| gpt2                               | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
| gpt_neox                           | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| gptj                               | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| llama                              | ✅            | ✅         | ✅         | ✅             | ✅                                                                                                 |
| moss                               | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| opt                                | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
| gpt_bigcode                        | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
| codegen                            | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
| falcon(RefinedWebModel/RefinedWeb) | ✅            | ✅         | ✅         | ✅             |                                                                                                   |
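To apply the `model.config.model_type` check described above without loading any weights, the configuration alone is enough; a small sketch using the same checkpoint as the quick-start example:

```python
# Check whether a checkpoint's model_type appears in the support table above,
# without downloading the full weights.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-125m")  # model used in the quick-start example
print(config.model_type)  # 'opt' -> listed as supported in the table above
```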

## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!

## Running tests

Tests can be run with:

```
pytest tests/ -s
```
### More Examples
For more examples, please turn to [examples](examples/README.md)

## Acknowledgement
- Special thanks to **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing the **GPTQ** algorithm and open-sourcing the [code](https://github.com/IST-DASLab/gptq).
- Special thanks to **qwopqwop200**; the quantization-related code in this project is mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).

[![Star History Chart](https://api.star-history.com/svg?repos=PanQiWei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)
155  README_zh.md
@@ -12,70 +12,37 @@
<p>
<a href="https://github.com/PanQiWei/AutoGPTQ/blob/main/README.md">English</a> |
<b>中文</b>
</p>
<p>
</h4>

## The road to v1.0.0

Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been significant for my career plans. Not long ago, I officially bid farewell to the startup team I joined for two years after graduation. I'm very grateful to the team's leaders and colleagues for their trust and guidance, which let me grow rapidly in two years; I'm also really grateful that the team allowed me to use its internal A100 GPU server cluster free of charge since the start of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, it can no longer be used, so **new hardware sponsorship would mean a lot to me!**) During those two years I served as an algorithm engineer, responsible for the architecture design and development of an LLM-based dialogue system. We successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), **an LLM-native AI agent platform where users can use multiple AI agents to build a highly automated team, letting them cooperate in a workflow to complete complex projects efficiently.**

Back to the topic: I'm very excited to see that in the past few months, research on optimizing LLM inference performance has made tremendous progress. We can now run LLM inference not only on high-end GPUs but also on CPUs and even edge devices. These advances make me eager to contribute more to the open source community, so I will first spend about four weeks gradually updating AutoGPTQ to the official v1.0.0 release; during this period 2~3 minor versions will also be released so users can try performance optimizations and new features in time. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will serve as a flexible, extendable quantization backend that supports all GPTQ-like methods and automatically quantizes LLMs written in PyTorch**. The development plan is detailed [here](https://github.com/PanQiWei/AutoGPTQ/issues/348); feel free to join the discussion and give your suggestions!

## News or Update

- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so running and training GPTQ models is now easier! Read [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, with [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel for at least a 1.3x inference speed-up on int4 quantized models.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq's CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to train adapters on GPTQ-quantized models; supports LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support downloading quantized models from, and uploading them to, the 🤗 Hub.
- 2023-05-04 - (Update) - Support using the faster cuda kernel when `not desc_act or group_size == -1`.
- 2023-04-29 - (Update) - Support loading quantized models from an arbitrary quantize_config and model_basename.
- 2023-04-28 - (Update) - Support CPU offload and quantization/inference on multiple devices, support `gpt2` type models.

*For more history, please turn to [here](docs/NEWS_OR_UPDATE.md)*

## Performance Comparison

### Inference Speed
> The results below are generated with [this script](examples/benchmark/generation_speed.py); the input batch size is 1, the decoding strategy is beam search, the model is forced to generate 512 tokens, and speed is measured in tokens/s (the larger, the better).
>
> The quantized model is loaded in the way that maximizes inference speed.

| model         | GPU           | num_beams | fp16  | gptq-int4 |
|---------------|---------------|-----------|-------|-----------|
| llama-7b      | 1xA100-40G    | 1         | 18.87 | 25.53     |
| llama-7b      | 1xA100-40G    | 4         | 68.79 | 91.30     |
| moss-moon 16b | 1xA100-40G    | 1         | 12.48 | 15.25     |
| moss-moon 16b | 1xA100-40G    | 4         | OOM   | 42.67     |
| moss-moon 16b | 2xA100-40G    | 1         | 06.83 | 06.78     |
| moss-moon 16b | 2xA100-40G    | 4         | 13.10 | 10.80     |
| gpt-j 6b      | 1xRTX3060-12G | 1         | OOM   | 29.55     |
| gpt-j 6b      | 1xRTX3060-12G | 4         | OOM   | 47.36     |

### Perplexity (PPL)
For a perplexity comparison, you can refer to [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#result) and [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#gptq-vs-bitsandbytes)

## Installation

### Quick Installation
You can install the latest stable release of AutoGPTQ from pip as pre-built wheels compatible with PyTorch 2.0.1:

* For CUDA 11.7: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/`
* For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`

**Warning:** The pre-built wheels are not guaranteed to work on PyTorch nightly. Please install AutoGPTQ from source when using a PyTorch nightly build.

You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
#### Disable CUDA extensions
By default, the CUDA extensions are installed when `torch` and CUDA are already present on your machine; if you don't want them, install with:
```shell
BUILD_CUDA_EXT=0 pip install auto-gptq
```
Also, to make sure the extension `autogptq_cuda` is no longer present in your virtual environment, run:
Also, to make sure the extension `quant_cuda` is no longer present in your virtual environment, run:
```shell
pip uninstall autogptq_cuda -y
pip uninstall quant_cuda -y
```
#### To support LLaMa models
If you want to try LLaMa models but your `transformers` version does not yet support them, run:
```shell
pip install auto-gptq[llama]
```

#### To support triton speedup
To use `triton` to speed up model inference, run:
> Warning: currently triton only supports Linux; 3-bit quantization is not supported when using triton

@@ -85,9 +52,6 @@ pip install auto-gptq[triton]
```

### Install from source
<details>
<summary>Click to see details</summary>

Clone the source code:
```shell
git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ

@@ -98,24 +62,25 @@ pip install .
```
As in the quick installation section, you can use `BUILD_CUDA_EXT=0` to skip building the CUDA extensions.

Use `.[llama]` if you want to use LLaMa models.

Use `.[triton]` if you want the triton speedup and it is supported on your operating system.

For AMD GPUs, to install from source with RoCm support, set the `ROCM_VERSION` environment variable. Compilation can be sped up by setting the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), e.g. `gfx90a` for MI200 series devices. Example:

```
ROCM_VERSION=5.6 pip install .
```
## Supported Models
Currently, `auto_gptq` supports the following models: `bloom`, `gpt2`, `gpt_neox`, `gptj`, `llama`, `moss` and `opt`; more Transformer models are coming soon!

On RoCm systems, the following packages additionally need to be installed before building from source: `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev`.
## Supported Evaluation Tasks
Currently, `auto_gptq` supports the following evaluation tasks: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more tasks are coming soon!

</details>
## Usage

## Quick Start
**For first-time users, it is strongly recommended to read the [tutorials](docs/tutorial) (continually updated...) before running the example scripts.**

### Quantization and Inference
### Basic Usage
> Warning: this is only a showcase of AutoGPTQ's basic APIs; it uses a single text sample to quantize a very small model, so the result may not perform as well as expected on large models.

The following shows the simplest use of `auto_gptq` for quantization and inference:
The following is the simplest usage example of `auto_gptq`:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

@@ -135,14 +100,13 @@ examples = [
quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize the model to 4-bit
    group_size=128,  # 128 is generally recommended
    desc_act=False,  # False significantly speeds up inference, but perplexity may become slightly worse
)

# load the un-quantized model; by default the model is always loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# quantize the model; the examples should be List[Dict] whose only keys are input_ids and attention_mask
model.quantize(examples)
model.quantize(examples, use_triton=False)

# save the quantized model
model.save_quantized(quantized_model_dir)
@@ -150,28 +114,11 @@ model.save_quantized(quantized_model_dir)
# save the quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

# push the quantized model directly to the Hugging Face Hub
# when using use_auth_token=True, make sure you have logged in first via huggingface-cli login
# or pass an explicit account token with use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)

# alternatively you can save the quantized model locally and push it to the Hub at the same time
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)

# load the quantized model onto the first visible GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")

# download the quantized model from the Hugging Face Hub and load it onto the first visible GPU
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)

# run inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to("cuda:0"))[0]))

# or use TextGenerationPipeline
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
@@ -181,11 +128,7 @@ print(pipeline("auto-gptq is")[0]["generated_text"])
See [this example script](examples/quantization/quant_with_alpaca.py) for more advanced usage.

### Customize Model

<details>

<summary>Below is an example of how to extend `auto_gptq` to support the `OPT` model; as you will see, it's very easy:</summary>

Below is an example of how to extend `auto_gptq` to support the `OPT` model; as you will see, it's very easy:
```python
from auto_gptq.modeling import BaseGPTQForCausalLM

@@ -211,18 +154,12 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
```
Then you can use `OPTGPTQForCausalLM.from_pretrained` and the other methods as shown in the Basic Usage section.

</details>

### Evaluation on Downstream Tasks
You can use the tasks defined in `auto_gptq.eval_tasks` to evaluate a model's performance on a specific downstream task before and after quantization.

These predefined tasks support all causal language models implemented in [🤗 transformers](https://github.com/huggingface/transformers) and in this project.

<details>

<summary>Below is an example of evaluating `EleutherAI/gpt-j-6b` on a sequence-classification (text-classification) task using the `cardiffnlp/tweet_sentiment_multilingual` dataset:</summary>

Below is an example of evaluating `EleutherAI/gpt-j-6b` on a sequence-classification (text-classification) task using the `cardiffnlp/tweet_sentiment_multilingual` dataset:
```python
from functools import partial

@@ -298,37 +235,9 @@ print(
)
```

</details>

## Learn More
[Tutorials](docs/tutorial) provide step-by-step guidance for integrating `auto_gptq` into your own project, plus best-practice principles.

[Examples](examples/README.md) provide plenty of example scripts for using `auto_gptq` in different ways.

## Supported Models

> You can use `model.config.model_type` against the table below to check whether the model you are using is supported by `auto_gptq`.
>
> For example, the `model_type` of `WizardLM`, `vicuna` and `gpt4all` is `llama`, so these models are all supported by `auto_gptq`.

| model type                         | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt                                                                |
|------------------------------------|--------------|-----------|-----------|---------------|-------------------------------------------------------------------------------------|
| bloom                              | ✅            | ✅         | ✅         | ✅             |                                                                                     |
| gpt2                               | ✅            | ✅         | ✅         | ✅             |                                                                                     |
| gpt_neox                           | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| gptj                               | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| llama                              | ✅            | ✅         | ✅         | ✅             | ✅                                                                                   |
| moss                               | ✅            | ✅         | ✅         | ✅             | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| opt                                | ✅            | ✅         | ✅         | ✅             |                                                                                     |
| gpt_bigcode                        | ✅            | ✅         | ✅         | ✅             |                                                                                     |
| codegen                            | ✅            | ✅         | ✅         | ✅             |                                                                                     |
| falcon(RefinedWebModel/RefinedWeb) | ✅            | ✅         | ✅         | ✅             |                                                                                     |

## Supported Evaluation Tasks
Currently, `auto_gptq` supports the following evaluation tasks: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more tasks are coming soon!
### More Examples
Please see [examples](examples/README.md) for more examples.

## Acknowledgement
- Special thanks to **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing the **GPTQ** algorithm and open-sourcing the [code](https://github.com/IST-DASLab/gptq).
- Special thanks to **qwopqwop200**; the model-quantization code in this project is mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).

[![Star History Chart](https://api.star-history.com/svg?repos=PanQiWei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)
auto_gptq/__init__.py
@@ -1,5 +1,2 @@
__version__ = "0.5.0.dev0"
from .modeling import BaseQuantizeConfig
from .modeling import AutoGPTQForCausalLM
from .utils.peft_utils import get_gptq_peft_model
from .utils.exllama_utils import exllama_set_max_input_length
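For orientation, here is a hedged sketch of how the two main-branch-only exports removed above are typically used; the exact argument names of `get_gptq_peft_model` and `exllama_set_max_input_length` are assumptions based on this release line, not confirmed by the diff.

```python
# Hedged sketch only: keyword names below are assumptions, not confirmed by this diff.
from auto_gptq import (
    AutoGPTQForCausalLM,
    exllama_set_max_input_length,
    get_gptq_peft_model,
)

model = AutoGPTQForCausalLM.from_quantized("opt-125m-4bit", device="cuda:0")

# Grow the exllama kernel's input buffer for long prompts (main branch only).
model = exllama_set_max_input_length(model, max_input_length=4096)

# Wrap the quantized model for PEFT training (main branch only); `my_lora_config`
# is a hypothetical peft config object you would construct yourself.
# peft_model = get_gptq_peft_model(model, peft_config=my_lora_config, train_mode=True)
```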
auto_gptq/modeling/__init__.py
@@ -7,11 +7,3 @@ from .gptj import *
from .llama import *
from .moss import *
from .opt import *
from .rw import *
from .gpt_bigcode import *
from .codegen import *
from .baichuan import *
from .internlm import *
from .qwen import *
from .mistral import *
from .mpt import *
auto_gptq/modeling/_base.py
@@ -1,10 +1,9 @@
import copy
import json
import warnings
import os
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join, isfile, isdir
from os.path import join, isfile
from typing import Dict, List, Optional, Union

import accelerate
@@ -13,21 +12,13 @@ import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module
from safetensors.torch import save_file as safe_save
from safetensors.torch import load_file as safe_load
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
from transformers.modeling_utils import no_init_weights
from transformers.utils.hub import PushToHubMixin

from ._const import *
from ._utils import *
from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import (
    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE, EXLLAMAV2_KERNELS_AVAILABLE
)

logger = getLogger(__name__)
@@ -38,11 +29,8 @@ class BaseQuantizeConfig(PushToHubMixin):
    group_size: int = field(default=-1)
    damp_percent: float = field(default=0.01)
    desc_act: bool = field(default=True)
    static_groups: bool = field(default=False)
    sym: bool = field(default=True)
    true_sequential: bool = field(default=True)
    model_name_or_path: Optional[str] = field(default=None)
    model_file_base_name: Optional[str] = field(default=None)

    def __post_init__(self):
        fields_info = fields(self)
@@ -59,48 +47,9 @@
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, save_dir: str, **kwargs):
        # Parameters related to loading from Hugging Face Hub
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        commit_hash = kwargs.pop("_commit_hash", None)

        quantize_config_filename = "quantize_config.json"
        if os.path.isdir(save_dir):  # Local
            resolved_config_file = join(save_dir, quantize_config_filename)
        else:  # Remote
            resolved_config_file = cached_file(
                save_dir,
                quantize_config_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                use_auth_token=use_auth_token,
                revision=revision,
                local_files_only=local_files_only,
                subfolder=subfolder,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
            )

        field_names = [field.name for field in fields(cls)]
        with open(resolved_config_file, "r", encoding="utf-8") as f:
            args_from_json = json.load(f)
            filtered_args = {}
            for key, val in args_from_json.items():
                if key in field_names:
                    filtered_args[key] = val
                else:
                    logger.warning(f"ignoring unknown parameter in {quantize_config_filename}: {key}.")
            return cls(**filtered_args)
    def from_pretrained(cls, save_dir: str):
        with open(join(save_dir, "quantize_config.json"), "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    def to_dict(self):
        return {
@@ -108,11 +57,8 @@
            "group_size": self.group_size,
            "damp_percent": self.damp_percent,
            "desc_act": self.desc_act,
            "static_groups": self.static_groups,
            "sym": self.sym,
            "true_sequential": self.true_sequential,
            "model_name_or_path": self.model_name_or_path,
            "model_file_base_name": self.model_file_base_name,
        }
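The two `from_pretrained` variants above differ mainly in Hub support: the main branch resolves `quantize_config.json` through `cached_file`, while v0.1.0 only reads it from a local directory. A hedged round-trip sketch (the `save_pretrained` name is assumed from the `json.dump` tail shown in the hunk above):

```python
# Hedged sketch of the quantize_config.json round-trip; `save_pretrained` is assumed
# to be the method whose tail (json.dump(self.to_dict(), ...)) appears above.
import os
from auto_gptq import BaseQuantizeConfig

os.makedirs("opt-125m-4bit", exist_ok=True)
cfg = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
cfg.save_pretrained("opt-125m-4bit")                        # writes opt-125m-4bit/quantize_config.json
cfg2 = BaseQuantizeConfig.from_pretrained("opt-125m-4bit")  # v0.1.0: local directory only
assert cfg2.bits == 4 and cfg2.group_size == 128
```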
@@ -123,19 +69,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
    inside_layer_modules: List[List[str]] = None
    lm_head_name: str = "lm_head"

    fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
    fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

    def __init__(
        self,
        model: PreTrainedModel,
        quantized: bool,
        quantize_config: BaseQuantizeConfig,
        is_triton_backend: bool = False,
        injected_fused_attention: bool = False,
        injected_fused_mlp: bool = False,
        trainable: bool = False
    ):
    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
        super().__init__()

        self.model = model
@@ -144,11 +78,6 @@
        self.quantize_config = quantize_config
        self.config = self.model.config

        self.is_triton_backend = is_triton_backend
        self.injected_fused_attention = injected_fused_attention
        self.injected_fused_mlp = injected_fused_mlp
        self.trainable = trainable

    @property
    def quantized(self):
        return self._quantized
@@ -212,22 +141,17 @@
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
        batch_size: int = 1,
        use_triton: bool = False,
        use_cuda_fp16: bool = True,
        autotune_warmup_after_quantized: bool = False,
        cache_examples_on_gpu: bool = True
    ):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is quantized.")
        if use_triton and not TRITON_AVAILABLE:
            logger.warning("triton is not installed, reset use_triton to False")
            use_triton = False

        device_map = self.hf_device_map
        if device_map:
            for name, device in device_map.items():
                if device == "cpu":
                    logger.info(f"truly offloading {name} to cpu with hook.")
                    module = get_module_by_name_suffix(self.model, name)
                    module = get_module_by_name(self.model, name)
                    remove_hook_from_module(module, recurse=True)
                    accelerate.cpu_offload_with_hook(module, CUDA_0)
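For reference, here is a minimal sketch of calling the main-branch `quantize()` signature shown above, continuing the README quick-start example (so `model` and `examples` come from there); the argument values are illustrative defaults, not a recommendation.

```python
# Minimal sketch of the main-branch quantize() call whose signature appears above.
# `examples` must be a list of dicts holding "input_ids" and "attention_mask".
model.quantize(
    examples,
    batch_size=1,                           # calibration samples processed per forward pass
    use_triton=False,                       # reset to False automatically if triton is missing
    use_cuda_fp16=True,
    autotune_warmup_after_quantized=False,
    cache_examples_on_gpu=True,
)
```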
@ -255,8 +179,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
break
|
||||
layer_inputs.append(move_to_device(inp, self.data_device))
|
||||
attention_masks.append(kwargs["attention_mask"].to(self.data_device))
|
||||
pos_ids = kwargs.get("position_ids", None)
|
||||
if pos_ids is not None:
|
||||
if (pos_ids := kwargs.get("position_ids", None)) is not None:
|
||||
position_ids.append(move_to_device(pos_ids, self.data_device))
|
||||
one_kwargs = dict()
|
||||
for k, v in kwargs.items(): # make sure other arguments also be captured
|
||||
|
@ -272,7 +195,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
self.model.config.use_cache = False
|
||||
|
||||
num_batches = len(examples)
|
||||
layers = get_module_by_name_prefix(self.model, self.layers_block_name)
|
||||
layers = get_module_by_name(self.model, self.layers_block_name)
|
||||
|
||||
force_layer_back_to_cpu = False
|
||||
if get_device(layers[0]) == CPU:
|
||||
|
@ -282,7 +205,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
cur_layer_device = get_device(layers[0])
|
||||
ori_outside_layer_module_devices = {}
|
||||
for module_name in self.outside_layer_modules:
|
||||
module = get_module_by_name_prefix(self.model, module_name)
|
||||
module = get_module_by_name(self.model, module_name)
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
@ -306,7 +229,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
move_to_device(layers[0], CPU if force_layer_back_to_cpu else cur_layer_device)
|
||||
for module_name in self.outside_layer_modules:
|
||||
module = get_module_by_name_prefix(self.model, module_name)
|
||||
module = get_module_by_name(self.model, module_name)
|
||||
if module is not None:
|
||||
move_to_device(module, ori_outside_layer_module_devices[module_name])
|
||||
|
||||
|
@ -341,7 +264,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
sym=self.quantize_config.sym,
|
||||
mse=False,
|
||||
)
|
||||
|
||||
def add_batch(name):
|
||||
def tmp(_, inp, out):
|
||||
gptq[name].add_batch(inp[0].data, out.data)
|
||||
|
@ -357,8 +279,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
additional_layer_inputs = {
|
||||
"attention_mask": layer_attention_mask
|
||||
}
|
||||
layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
|
||||
if layer_position_ids is not None:
|
||||
if (
|
||||
layer_position_ids := None if not position_ids
|
||||
else move_to_device(position_ids[j], cur_layer_device)
|
||||
) is not None:
|
||||
additional_layer_inputs["position_ids"] = layer_position_ids
|
||||
for k, v in layer_input_kwargs[j].items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
@ -373,9 +297,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
logger.info(f'Quantizing {name} in layer {i + 1}/{len(layers)}...')
|
||||
scale, zero, g_idx = gptq[name].fasterquant(
|
||||
percdamp=self.quantize_config.damp_percent,
|
||||
group_size=self.quantize_config.group_size,
|
||||
actorder=self.quantize_config.desc_act,
|
||||
static_groups=self.quantize_config.static_groups
|
||||
groupsize=self.quantize_config.group_size,
|
||||
actorder=self.quantize_config.desc_act
|
||||
)
|
||||
quantizers[f'{self.layers_block_name}.{i}.{name}'] = (
|
||||
gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device),
|
||||
|
@ -391,8 +314,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
additional_layer_inputs = {
|
||||
"attention_mask": layer_attention_mask
|
||||
}
|
||||
layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
|
||||
if layer_position_ids is not None:
|
||||
if (
|
||||
layer_position_ids := None if not position_ids
|
||||
else move_to_device(position_ids[j], cur_layer_device)
|
||||
) is not None:
|
||||
additional_layer_inputs["position_ids"] = layer_position_ids
|
||||
for k, v in layer_input_kwargs[j].items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
@ -418,14 +343,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
bits=self.quantize_config.bits,
|
||||
group_size=self.quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=self.quantize_config.desc_act,
|
||||
warmup_triton=autotune_warmup_after_quantized,
|
||||
autotune_warmup=autotune_warmup_after_quantized,
|
||||
force_layer_back_to_cpu=force_layer_back_to_cpu
|
||||
)
|
||||
if device_map:
|
||||
self.model = remove_hook_from_module(self.model, recurse=True)
|
||||
self.model = simple_dispatch_model(self.model, device_map)
|
||||
self.model = accelerate.dispatch_model(self.model, device_map, offload_buffers=True)
|
||||
self.model.config.use_cache = forward_pass_use_cache
|
||||
|
||||
self._quantized = True
|
||||
|
@@ -434,18 +358,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):

    @property
    def device(self):
        if not self.hf_device_map:
            return self.model.device
        else:
            device = [d for d in self.hf_device_map.values() if d not in {'cpu', 'disk'}][0]
            return torch.device(device)

    def to(self, device: Union[str, torch.device]):
        self.model.to(device)
        return self
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    def forward(self, **kwargs):
        return self.model(**kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""
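# A minimal usage sketch (not part of the diff): the wrapper keeps the familiar
# nn.Module-style interface, so a loaded model can be moved and queried like a
# plain Transformers model. `model` and `tokenizer` here are hypothetical objects
# obtained from the loading helpers defined later in this file.
#
#     model = model.to("cuda:0")            # on the main branch, `to` returns the wrapper itself, so calls chain
#     inputs = tokenizer("auto-gptq is", return_tensors="pt").to(model.device)
#     output_ids = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(output_ids[0]))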
|
@ -456,78 +375,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
"""shortcut for model.prepare_inputs_for_generation"""
|
||||
return self.model.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_id: str,
|
||||
save_dir: Optional[str] = None,
|
||||
use_safetensors: Optional[bool] = True,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
commit_message: Optional[str] = "Upload of AutoGPTQ quantized model",
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
private: Optional[bool] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
create_pr: Optional[bool] = False,
|
||||
) -> str:
|
||||
"""
|
||||
Upload the model to the Hugging Face Hub.
|
||||
|
||||
Parameters:
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push your tool to. It should contain your organization name when
|
||||
pushing to a given organization.
|
||||
save_dir (`str`, *optional*):
|
||||
The name of the local folder to save the model to.
|
||||
If the model has already been saved, this parameter can be omitted.
|
||||
use_safetensors (`bool`, *optional*):
|
||||
Save the model using `safetensors`.
|
||||
If the model has already been saved, this parameter can be omitted.
|
||||
safetensors_metadata: (`dict`, *optional*, defaults to `None`):
|
||||
Pass optional metadata dictionary to be saved in the `safetensors` model file(s).
|
||||
Metadata is optional and is purely for informational purposes. It does not affect inference.
|
||||
If `None`, no metadata will be saved.
|
||||
commit_message (`str`, *optional*, defaults to `"Upload tool"`):
|
||||
Message to commit while pushing.
|
||||
use_auth_token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
|
||||
is not specified.
|
||||
private (`bool`, *optional*):
|
||||
Whether or not the repository created should be private.
|
||||
token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
"""
|
||||
if (self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)) and save_dir is None:
|
||||
raise ValueError("Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading.")
|
||||
|
||||
if save_dir is not None:
|
||||
logger.info(f"Saving model to {save_dir}")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
|
||||
repo_url = create_repo(
|
||||
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"
|
||||
)
|
||||
repo_id = repo_url.repo_id
|
||||
|
||||
if self.quantize_config.model_name_or_path is not None:
|
||||
work_dir = self.quantize_config.model_name_or_path
|
||||
operations = [
|
||||
CommitOperationAdd(path_or_fileobj=join(work_dir, f), path_in_repo=f)
|
||||
for f in os.listdir(work_dir)
|
||||
]
|
||||
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
||||
return create_commit(
|
||||
repo_id=repo_id,
|
||||
operations=operations,
|
||||
commit_message=commit_message,
|
||||
token=use_auth_token,
|
||||
create_pr=create_pr,
|
||||
repo_type="model",
|
||||
)
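# A hedged usage sketch for the method above (not part of the diff). The repository
# id and folder name are hypothetical; `quantized_model` is assumed to be an already
# quantized BaseGPTQForCausalLM instance and the Hub token to come from
# `huggingface-cli login`.
#
#     quantized_model.push_to_hub(
#         repo_id="my-org/opt-125m-4bit-gptq",          # hypothetical repository name
#         save_dir="opt-125m-4bit-gptq",                # saved locally first, then uploaded
#         use_safetensors=True,
#         safetensors_metadata={"note": "demo upload"},
#         commit_message="Upload of AutoGPTQ quantized model",
#         private=True,
#     )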
|
||||
|
||||
def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
|
||||
def save_quantized(self, save_dir: str, use_safetensors: bool = False):
|
||||
"""save quantized model and configs to local disk"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
@ -536,60 +384,23 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
self.model.to(CPU)
|
||||
|
||||
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
model_save_name = f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
if use_safetensors:
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
model_save_name += ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
||||
if safetensors_metadata is None:
|
||||
safetensors_metadata = {}
|
||||
elif not isinstance(safetensors_metadata, dict):
|
||||
raise TypeError("safetensors_metadata must be a dictionary.")
|
||||
safe_save(state_dict, join(save_dir, model_save_name))
|
||||
else:
|
||||
logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
|
||||
new_safetensors_metadata = {}
|
||||
converted_keys = False
|
||||
for key, value in safetensors_metadata.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
converted_keys = True
|
||||
try:
|
||||
new_key = str(key)
|
||||
new_value = str(value)
|
||||
except Exception as e:
|
||||
raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
|
||||
if new_key in new_safetensors_metadata:
|
||||
logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
|
||||
new_safetensors_metadata[new_key] = new_value
|
||||
safetensors_metadata = new_safetensors_metadata
|
||||
if converted_keys:
|
||||
logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
|
||||
|
||||
# Format is required to enable Accelerate to load the metadata
|
||||
# otherwise it raises an OSError
|
||||
safetensors_metadata['format'] = "pt"
|
||||
|
||||
# Store the quantization configuration as safetensors metadata
|
||||
from auto_gptq import __version__
|
||||
safetensors_metadata['auto_gptq_version'] = str(__version__)
|
||||
safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
|
||||
safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
|
||||
safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
|
||||
safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
|
||||
|
||||
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
|
||||
else:
|
||||
model_save_name = model_base_name + ".bin"
|
||||
model_save_name += ".bin"
|
||||
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
|
||||
|
||||
self.model.config.save_pretrained(save_dir)
|
||||
self.quantize_config.save_pretrained(save_dir)
|
||||
self.quantize_config.model_name_or_path = save_dir
|
||||
self.quantize_config.model_file_base_name = model_base_name
|
||||
|
||||
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
|
||||
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
|
||||
"""alias of save_quantized"""
|
||||
logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
self.save_quantized(save_dir, use_safetensors)
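# A small sketch of the metadata round-trip implemented by save_quantized above
# (not part of the diff). Directory and file names are hypothetical; the exact
# file name depends on bits/group_size. `safe_open` is the standard reader from
# the `safetensors` package and exposes the metadata written at save time.
#
#     quantized_model.save_quantized(
#         "opt-125m-4bit-gptq",
#         use_safetensors=True,
#         safetensors_metadata={"source": "internal-experiment-42"},   # free-form, informational only
#     )
#
#     from safetensors import safe_open
#     with safe_open("opt-125m-4bit-gptq/gptq_model-4bit-128g.safetensors", framework="pt") as f:
#         print(f.metadata())   # contains 'format', 'auto_gptq_version', 'gptq_bits', ... plus user keys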
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
@ -597,8 +408,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
pretrained_model_name_or_path: str,
|
||||
quantize_config: BaseQuantizeConfig,
|
||||
max_memory: Optional[dict] = None,
|
||||
trust_remote_code: bool = False,
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
**model_init_kwargs
|
||||
):
|
||||
"""load un-quantized pretrained model to cpu"""
|
||||
|
@ -613,35 +422,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = model_init_kwargs.pop("cache_dir", None)
|
||||
force_download = model_init_kwargs.pop("force_download", False)
|
||||
resume_download = model_init_kwargs.pop("resume_download", False)
|
||||
proxies = model_init_kwargs.pop("proxies", None)
|
||||
local_files_only = model_init_kwargs.pop("local_files_only", False)
|
||||
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
|
||||
revision = model_init_kwargs.pop("revision", None)
|
||||
subfolder = model_init_kwargs.pop("subfolder", "")
|
||||
commit_hash = model_init_kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
# enforce some values despite user specified
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
model_init_kwargs["trust_remote_code"] = trust_remote_code
|
||||
model_init_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
model_init_kwargs["trust_remote_code"] = True
|
||||
if max_memory:
|
||||
if "disk" in max_memory:
|
||||
raise NotImplementedError("disk offload not support yet.")
|
||||
|
@ -671,9 +458,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
|
@ -691,180 +476,55 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
@classmethod
|
||||
def from_quantized(
|
||||
cls,
|
||||
model_name_or_path: Optional[str],
|
||||
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
|
||||
max_memory: Optional[dict] = None,
|
||||
device: Optional[Union[str, int]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
save_dir: str,
|
||||
device: str = "cpu",
|
||||
use_safetensors: bool = False,
|
||||
use_triton: bool = False,
|
||||
use_qigen: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True,
|
||||
use_cuda_fp16: bool = True,
|
||||
max_memory: Optional[dict] = None,
|
||||
device_map: Optional[str] = None,
|
||||
quantize_config: Optional[BaseQuantizeConfig] = None,
|
||||
model_basename: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
**kwargs
|
||||
trust_remote_code: bool = False
|
||||
):
|
||||
"""load quantized model from local disk"""
|
||||
if use_triton:
|
||||
from ..nn_modules.qlinear_triton import autotune_warmup_linear
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_commit_hash": commit_hash,
|
||||
}
|
||||
if use_qigen and not QIGEN_AVAILABLE:
|
||||
logger.warning("Qigen is not installed, reset use_qigen to False.")
|
||||
use_qigen = False
|
||||
if use_triton and not TRITON_AVAILABLE:
|
||||
logger.warning("Triton is not installed, reset use_triton to False.")
|
||||
use_triton = False
|
||||
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
|
||||
logger.warning(
|
||||
"Exllama kernel is not installed, reset disable_exllama to True. "
|
||||
"This may because you installed auto_gptq using a pre-build wheel "
|
||||
"on Windows, in which exllama_kernels are not compiled. To use "
|
||||
"exllama_kernels to further speedup inference, you can re-install "
|
||||
"auto_gptq from source."
|
||||
)
|
||||
disable_exllama = True
|
||||
if not disable_exllamav2 and not EXLLAMAV2_KERNELS_AVAILABLE:
|
||||
logger.warning(
|
||||
"Exllamav2 kernel is not installed, reset disable_exllamav2 to True. "
|
||||
"This may because you installed auto_gptq using a pre-build wheel "
|
||||
"on Windows, in which exllama_kernels are not compiled. To use "
|
||||
"exllama_kernels to further speedup inference, you can re-install "
|
||||
"auto_gptq from source."
|
||||
)
|
||||
disable_exllamav2 = True
|
||||
if not AUTOGPTQ_CUDA_AVAILABLE:
|
||||
logger.warning(
|
||||
"CUDA kernels for auto_gptq are not installed, this will result in "
|
||||
"very slow inference speed. This may because:\n"
|
||||
"1. You disabled CUDA extensions compilation by setting BUILD_CUDA_EXT=0 when install auto_gptq from source.\n"
|
||||
"2. You are using pytorch without CUDA support.\n"
|
||||
"3. CUDA and nvcc are not installed in your device."
|
||||
)
|
||||
|
||||
if use_qigen and QIGEN_AVAILABLE:
|
||||
logger.warning("QIgen is active. Ignores all settings related to cuda.")
|
||||
inject_fused_attention = False
|
||||
inject_fused_mlp = False
|
||||
use_triton = False
|
||||
disable_exllama = True
|
||||
disable_exllamav2 = True
|
||||
|
||||
if not disable_exllamav2 and not disable_exllama:
|
||||
logger.warning(
|
||||
"You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False"
|
||||
)
|
||||
disable_exllama = True
|
||||
|
||||
# == step1: prepare configs and file names == #
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
||||
logger.warning("use_triton will force moving the whole model to GPU, make sure you have enough VRAM.")
|
||||
device = "cuda:0"
|
||||
|
||||
config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
if quantize_config is None:
|
||||
quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)
|
||||
quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)
|
||||
|
||||
if model_basename is None:
|
||||
if quantize_config.model_file_base_name:
|
||||
model_basename = quantize_config.model_file_base_name
|
||||
else:
|
||||
model_basename = f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g"
|
||||
|
||||
quantize_config.model_name_or_path = model_name_or_path
|
||||
quantize_config.model_file_base_name = model_basename
|
||||
model_save_name = join(save_dir, model_basename)
|
||||
|
||||
extensions = []
|
||||
if use_safetensors:
|
||||
extensions.append(".safetensors")
|
||||
model_save_name += ".safetensors"
|
||||
else:
|
||||
extensions += [".bin", ".pt"]
|
||||
model_save_name += ".bin"
|
||||
|
||||
model_name_or_path = str(model_name_or_path)
|
||||
is_local = isdir(model_name_or_path)
|
||||
if not isfile(model_save_name):
|
||||
raise FileNotFoundError(f"Could not find model at {model_save_name}")
|
||||
|
||||
resolved_archive_file = None
|
||||
if is_local:
|
||||
model_save_name = join(model_name_or_path, model_basename)
|
||||
for ext in extensions:
|
||||
if isfile(model_save_name + ext):
|
||||
resolved_archive_file = model_save_name + ext
|
||||
break
|
||||
else: # remote
|
||||
for ext in extensions:
|
||||
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
|
||||
if resolved_archive_file is not None:
|
||||
break
|
||||
|
||||
if resolved_archive_file is None: # Could not find a model file to use
|
||||
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
|
||||
|
||||
model_save_name = resolved_archive_file
|
||||
|
||||
if (not disable_exllama or not disable_exllamav2) and trainable:
|
||||
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
disable_exllama = True
|
||||
disable_exllamav2 = True
|
||||
|
||||
elif not use_triton and trainable:
|
||||
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
|
||||
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
if torch_dtype is None:
|
||||
if not use_qigen:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
|
||||
if not use_qigen:
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
transformers.modeling_utils._init_weights = False
|
||||
|
||||
init_contexts = [no_init_weights()]
|
||||
if low_cpu_mem_usage:
|
||||
init_contexts.append(accelerate.init_empty_weights(include_buffers=False))
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
torch.set_default_dtype(torch.half)
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code)
|
||||
torch.set_default_dtype(torch.float)
|
||||
layers = find_layers(model)
|
||||
ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
|
||||
for name in list(layers.keys()):
|
||||
|
@ -872,102 +532,19 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
logger.info(f"{name} not been quantized, will be ignored when make_quant.")
|
||||
del layers[name]
|
||||
|
||||
make_quant(
|
||||
model,
|
||||
layers,
|
||||
quantize_config.bits,
|
||||
quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
make_quant(model, layers, quantize_config.bits, quantize_config.group_size, use_triton=use_triton, desc_act=quantize_config.desc_act)
|
||||
model.tie_weights()
|
||||
|
||||
# == step3: load checkpoint and dispatch == #
|
||||
if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
||||
raise ValueError(
|
||||
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
|
||||
"'sequential'."
|
||||
)
|
||||
if isinstance(device_map, dict):
|
||||
max_memory = None
|
||||
else:
|
||||
if device is None and not device_map and not max_memory:
|
||||
if max_memory and not device_map:
|
||||
device_map = "auto"
|
||||
if device is not None:
|
||||
device = torch.device(device)
|
||||
if not max_memory and not device_map:
|
||||
device_map = {"": device.index if device.type == "cuda" else device.type}
|
||||
if not isinstance(device_map, dict) and device_map != "sequential":
|
||||
max_memory = accelerate.utils.get_balanced_memory(
|
||||
model=model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type],
|
||||
low_zero=(device_map == "balanced_low_0")
|
||||
)
|
||||
if not isinstance(device_map, dict):
|
||||
device_map = accelerate.infer_auto_device_map(
|
||||
model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type]
|
||||
device_map = {"": device}
|
||||
|
||||
model = accelerate.load_checkpoint_and_dispatch(
|
||||
model, model_save_name, device_map, max_memory, no_split_module_classes=[cls.layer_type]
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits)
|
||||
|
||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||
model,
|
||||
checkpoint=model_save_name,
|
||||
device_map=device_map,
|
||||
offload_state_dict=True,
|
||||
offload_buffers=True
|
||||
)
|
||||
model = simple_dispatch_model(model, device_map)
|
||||
else:
|
||||
if quantize_config.desc_act:
|
||||
NotImplementedError('desc_act=True is not yet supported.')
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
layers = find_layers(model)
|
||||
ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
|
||||
for name in list(layers.keys()):
|
||||
if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
|
||||
logger.info(f"{name} not been quantized, will be ignored when make_quant.")
|
||||
del layers[name]
|
||||
|
||||
if model_save_name.endswith('.safetensors'):
|
||||
checkpoint = safe_load(model_save_name)
|
||||
else:
|
||||
checkpoint = torch.load(model_save_name)
|
||||
make_quant(
|
||||
model,
|
||||
layers,
|
||||
quantize_config.bits,
|
||||
quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable,
|
||||
use_qigen=True
|
||||
)
|
||||
preprocess_checkpoint_qigen(
|
||||
model,
|
||||
layers,
|
||||
quantize_config.bits,
|
||||
quantize_config.group_size,
|
||||
checkpoint
|
||||
)
|
||||
model.load_state_dict(checkpoint)
|
||||
# == step4: set seqlen == #
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
|
@ -979,94 +556,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
logger.warning("can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
|
||||
# == step5: (optional) inject optimized module == #
|
||||
if inject_fused_attention:
|
||||
if cls.fused_attn_module_type is None:
|
||||
inject_fused_attention = False
|
||||
logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
|
||||
else:
|
||||
cls.fused_attn_module_type.inject_to_model(
|
||||
model,
|
||||
use_triton=use_triton,
|
||||
group_size=quantize_config.group_size,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable,
|
||||
bits=quantize_config.bits,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2
|
||||
)
|
||||
if inject_fused_mlp:
|
||||
if cls.fused_mlp_module_type is None:
|
||||
inject_fused_mlp = False
|
||||
logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
|
||||
else:
|
||||
cls.fused_mlp_module_type.inject_to_model(
|
||||
model,
|
||||
use_triton=use_triton
|
||||
)
|
||||
|
||||
# Any post-initialization that require device information, for example buffers initialization on device.
|
||||
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
|
||||
|
||||
model.eval()
|
||||
# == step6: (optional) warmup triton == #
|
||||
if use_triton and warmup_triton:
|
||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||
QuantLinear.warmup(model, seqlen=model.seqlen)
|
||||
|
||||
if inject_fused_mlp and cls.fused_mlp_module_type is not None:
|
||||
cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)
|
||||
if use_triton:
|
||||
autotune_warmup_linear(model, seqlen=model.seqlen)
|
||||
|
||||
# == step7: make model compatible with peft
|
||||
cls.make_sure_compatible_with_peft(
|
||||
model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits
|
||||
)
|
||||
return cls(model, True, quantize_config)
|
||||
|
||||
return cls(
|
||||
model,
|
||||
True,
|
||||
quantize_config,
|
||||
is_triton_backend=use_triton,
|
||||
injected_fused_attention=inject_fused_attention,
|
||||
injected_fused_mlp=inject_fused_mlp and use_triton,
|
||||
trainable=trainable
|
||||
)
|
||||
|
||||
def warmup_triton(self, enabled: bool = True):
|
||||
if not enabled:
|
||||
return
|
||||
if not TRITON_AVAILABLE:
|
||||
logger.warning(f"triton is not available, skip warmup stage directly.")
|
||||
return
|
||||
|
||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||
QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
|
||||
|
||||
if self.fused_mlp_module_type is not None:
|
||||
self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)
|
||||
|
||||
def enable_trainable_mode(self, enabled: bool = True):
|
||||
if not self.is_triton_backend and enabled:
|
||||
raise NotImplementedError("For now, trainable mode only supports triton backend.")
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "trainable"):
|
||||
setattr(m, "trainable", enabled)
|
||||
|
||||
def disable_trainable_mode(self):
|
||||
self.enable_trainable_mode(enabled=False)
|
||||
|
||||
@staticmethod
|
||||
def make_sure_compatible_with_peft(model: PreTrainedModel, use_triton: bool, desc_act: bool, group_size: int, bits: int):
|
||||
GeneralQuantLinear.inject_to_model(
|
||||
model,
|
||||
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
|
||||
)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except:
|
||||
return getattr(self.model, item)
|
||||
|
||||
__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]
|
||||
|
|
|
@@ -1,36 +1,13 @@
from packaging.version import parse as parse_version

from torch import device

from ..utils.import_utils import compare_transformers_version
from transformers import __version__ as transformers_version

CPU = device("cpu")
CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = [
    "bloom",
    "gptj",
    "gpt2",
    "gpt_neox",
    "opt",
    "moss",
    "gpt_bigcode",
    "codegen",
    "RefinedWebModel",
    "RefinedWeb",
    "baichuan",
    "internlm",
    "qwen",
    "mpt",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
if parse_version(transformers_version) >= parse_version("v4.28.0"):
    SUPPORTED_MODELS.append("llama")
if compare_transformers_version("v4.33.0", op="ge"):
    SUPPORTED_MODELS.append("falcon")
if compare_transformers_version("v4.34.0", op="ge"):
    SUPPORTED_MODELS.append("mistral")


EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]

@@ -1,14 +1,13 @@
from logging import getLogger
from typing import Union, Optional
from typing import Union

import accelerate
import torch
import torch.nn as nn
from transformers import AutoConfig
import transformers

from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from ..utils.import_utils import dynamically_import_QuantLinear
from ._const import SUPPORTED_MODELS, CPU, CUDA_0


logger = getLogger(__name__)

@@ -28,8 +27,8 @@ def move_to_device(obj: Union[torch.Tensor, nn.Module], device: torch.device):
def find_layers(module, layers=None, name=''):
    if not layers:
        layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]
    for layer in layers:
        if isinstance(module, layer):

    if type(module) in layers:
            return {name: module}
    res = {}
    for name1, child in module.named_children():
@@ -37,33 +36,20 @@ def find_layers(module, layers=None, name=''):
    return res


def get_module_by_name_prefix(model, module_name: str):
def get_module_by_name(model, module_name: str):
    for name, module in model.named_modules():
        if name.startswith(module_name):
            return module


def get_module_by_name_suffix(model, module_name: str):
    for name, module in model.named_modules():
        if name.endswith(module_name):
            return module

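# Illustrative sketch (not part of the diff) of how the two lookups above differ.
# The module names are hypothetical LLaMA-style names.
#
#     get_module_by_name_prefix(model, "model.embed_tokens")   # matches "model.embed_tokens" and anything under it
#     get_module_by_name_suffix(model, "mlp.down_proj")        # matches e.g. "model.layers.0.mlp.down_proj"
#
# The prefix form is what BaseGPTQForCausalLM uses for the blocks named in
# `outside_layer_modules`, while the suffix form is handy when only the tail of a
# dotted module path is known, as in the CPU-offload hooks installed by
# `simple_dispatch_model` further down in this file.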
def make_quant(
|
||||
module,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
name='',
|
||||
use_triton: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
use_qigen: bool = False,
|
||||
use_cuda_fp16: bool = True,
|
||||
desc_act: bool = False,
|
||||
trainable: bool = False
|
||||
):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen)
|
||||
def make_quant(module, names, bits, groupsize, name='', use_triton=False, desc_act=False):
|
||||
if use_triton:
|
||||
from ..nn_modules.qlinear_triton import QuantLinear
|
||||
else:
|
||||
if not desc_act or groupsize == -1:
|
||||
from ..nn_modules.qlinear_old import QuantLinear
|
||||
else:
|
||||
from ..nn_modules.qlinear import QuantLinear
|
||||
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
|
@ -73,109 +59,21 @@ def make_quant(
|
|||
if name1 in names:
|
||||
ori_layer_device = get_device(getattr(module, attr))
|
||||
delattr(module, attr)
|
||||
if isinstance(tmp,nn.Linear):
|
||||
if type(tmp) == nn.Linear:
|
||||
in_features = tmp.in_features
|
||||
out_features = tmp.out_features
|
||||
elif isinstance(tmp,nn.Conv2d):
|
||||
elif type(tmp) == nn.Conv2d:
|
||||
in_features = tmp.in_channels
|
||||
out_features = tmp.out_channels
|
||||
elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
|
||||
elif type(tmp) == transformers.pytorch_utils.Conv1D:
|
||||
in_features = tmp.weight.shape[0]
|
||||
out_features = tmp.weight.shape[1]
|
||||
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
|
||||
new_layer = QuantLinear(
|
||||
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
|
||||
)
|
||||
else:
|
||||
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
|
||||
new_layer = QuantLinear(bits, groupsize, in_features, out_features, tmp.bias is not None)
|
||||
new_layer.device = ori_layer_device
|
||||
setattr(module, attr, new_layer.to(ori_layer_device))
|
||||
for name1, child in module.named_children():
|
||||
make_quant(
|
||||
child,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
name + '.' + name1 if name != '' else name1,
|
||||
use_triton=use_triton,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=desc_act,
|
||||
trainable=trainable,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_qigen=use_qigen
|
||||
)
|
||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1, use_triton=use_triton, desc_act=desc_act)
|
||||
|
||||
def preprocess_checkpoint_qigen(
|
||||
module,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
checkpoint,
|
||||
name='',
|
||||
):
|
||||
try:
|
||||
import cQIGen as qinfer
|
||||
except ImportError:
|
||||
logger.error('cQIGen not installed.')
|
||||
raise
|
||||
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
|
||||
if isinstance(module, QuantLinear):
|
||||
in_features = module.infeatures
|
||||
out_features = module.outfeatures
|
||||
|
||||
zeros = checkpoint[name + '.qzeros']
|
||||
scales = checkpoint[name + '.scales'].float()
|
||||
|
||||
if zeros.dtype != torch.float32:
|
||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
||||
if bits == 4:
|
||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 2:
|
||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 3:
|
||||
logger.info("Unpacking zeros for 3 bits")
|
||||
new_scales = scales.contiguous()
|
||||
else:
|
||||
if scales.shape[1] != out_features:
|
||||
new_scales = scales.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_scales = scales.contiguous()
|
||||
if zeros.shape[1] != out_features:
|
||||
new_zeros = zeros.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_zeros = zeros.contiguous()
|
||||
|
||||
checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
|
||||
del checkpoint[name + '.qzeros']
|
||||
del checkpoint[name + '.g_idx']
|
||||
if name + '.bias' in checkpoint:
|
||||
checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
|
||||
else:
|
||||
checkpoint[name + '.bias'] = torch.zeros(out_features)
|
||||
checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
|
||||
if bits == 4:
|
||||
qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
|
||||
qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
|
||||
elif bits == 3:
|
||||
qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
|
||||
qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
|
||||
elif bits == 2:
|
||||
qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
|
||||
qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
|
||||
checkpoint[name + '.qweight'] = qweight
|
||||
return
|
||||
|
||||
for name1, child in module.named_children():
|
||||
preprocess_checkpoint_qigen(
|
||||
child,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
checkpoint,
|
||||
name + '.' + name1 if name != '' else name1,
|
||||
)
|
||||
|
||||
def pack_model(
|
||||
model,
|
||||
|
@ -183,20 +81,24 @@ def pack_model(
|
|||
bits,
|
||||
group_size,
|
||||
use_triton=False,
|
||||
use_cuda_fp16=True,
|
||||
desc_act=False,
|
||||
warmup_triton: bool = False,
|
||||
autotune_warmup: bool = False,
|
||||
force_layer_back_to_cpu: bool = False
|
||||
):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=False, disable_exllamav2=True)
|
||||
|
||||
if use_triton:
|
||||
from ..nn_modules.qlinear_triton import QuantLinear, autotune_warmup_linear
|
||||
else:
|
||||
if not desc_act or group_size == -1:
|
||||
from ..nn_modules.qlinear_old import QuantLinear
|
||||
else:
|
||||
from ..nn_modules.qlinear import QuantLinear
|
||||
if force_layer_back_to_cpu:
|
||||
model.to(CPU)
|
||||
|
||||
logger.info('Packing model...')
|
||||
layers = find_layers(model)
|
||||
layers = {n: layers[n] for n in quantizers}
|
||||
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act, disable_exllama=False, disable_exllamav2=True)
|
||||
make_quant(model, quantizers, bits, group_size, use_triton=use_triton)
|
||||
qlayers = find_layers(model, [QuantLinear])
|
||||
for name in qlayers:
|
||||
logger.info(name)
|
||||
|
@ -209,185 +111,27 @@ def pack_model(
|
|||
qlayers[name].to(layer_device)
|
||||
logger.info('Model packed.')
|
||||
|
||||
if use_triton and warmup_triton:
|
||||
if use_triton and autotune_warmup:
|
||||
logger.warning(
|
||||
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
|
||||
)
|
||||
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
|
||||
autotune_warmup_linear(model.to(CUDA_0), seqlen=model.seqlen)
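# A hedged sketch of how pack_model is driven (not part of the diff). The
# `quantizers` mapping is the dict built by `BaseGPTQForCausalLM.quantize()`
# (module name -> quantizer/scale/zero/g_idx); treat the exact position of that
# argument in the signature as an assumption inferred from the body above.
#
#     pack_model(
#         model=model,
#         quantizers=quantizers,
#         bits=quantize_config.bits,
#         group_size=quantize_config.group_size,
#         use_triton=False,
#         use_cuda_fp16=True,
#         desc_act=quantize_config.desc_act,
#         warmup_triton=False,                 # named `autotune_warmup` on the v0.1.0 side
#         force_layer_back_to_cpu=True,
#     )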
|
||||
|
||||
|
||||
def check_and_get_model_type(model_dir, trust_remote_code=False):
|
||||
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
|
||||
def check_and_get_model_type(model_dir):
|
||||
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
model_type = config.model_type
|
||||
return model_type
|
||||
|
||||
|
||||
def simple_dispatch_model(model, device_map):
|
||||
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
|
||||
|
||||
if "" in device_map:
|
||||
d = device_map[""]
|
||||
model = model.to(torch.device(d))
|
||||
model.hf_device_map = device_map
|
||||
return model
|
||||
|
||||
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
|
||||
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
|
||||
main_device = "cpu"
|
||||
else:
|
||||
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
|
||||
|
||||
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
|
||||
prev_hook = None
|
||||
for idx, (n, d) in enumerate(cpu_offload_group):
|
||||
m = get_module_by_name_suffix(model, n)
|
||||
_, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
|
||||
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
|
||||
if len(cpu_offload_group) > 1:
|
||||
get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
|
||||
|
||||
for n, d in device_map.items():
|
||||
m = get_module_by_name_suffix(model, n)
|
||||
if d != "cpu":
|
||||
d = torch.device(d)
|
||||
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
|
||||
add_hook_to_module(m, hook)
|
||||
accelerate.utils.modeling.retie_parameters(model, tied_params)
|
||||
model.hf_device_map = device_map
|
||||
|
||||
return model
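# Illustrative sketch (not part of the diff): `simple_dispatch_model` expects an
# accelerate-style map from module name to device. The layer names below are
# hypothetical; modules mapped to "cpu" get offload-with-hook treatment, the rest
# get an AlignDevicesHook so inputs and outputs land on the right device.
#
#     device_map = {
#         "model.embed_tokens": 0,
#         "model.layers.0": 0,
#         "model.layers.1": "cpu",
#         "lm_head": 0,
#     }
#     model = simple_dispatch_model(model, device_map)
#     print(model.hf_device_map)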
|
||||
|
||||
|
||||
def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
|
||||
"""
|
||||
The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state.
|
||||
"""
|
||||
device_to_buffers_size = {}
|
||||
|
||||
model_uses_exllama = False
|
||||
for name, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
|
||||
model_uses_exllama = True
|
||||
device = submodule.qweight.device
|
||||
if device not in device_to_buffers_size:
|
||||
device_to_buffers_size[device] = {
|
||||
"max_dq_buffer_size": 1,
|
||||
"max_inner_outer_dim": 1
|
||||
}
|
||||
|
||||
if not use_act_order:
|
||||
submodule._use_act_order = False
|
||||
else:
|
||||
submodule._use_act_order = True
|
||||
|
||||
# Disable this heuristic for detecting act_order, but it could be used instead of the config.
|
||||
"""
|
||||
if submodule.g_idx is None:
|
||||
submodule.act_order = False
|
||||
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
|
||||
submodule.g_idx = None
|
||||
submodule.act_order = False
|
||||
else:
|
||||
submodule.act_order = True
|
||||
"""
|
||||
|
||||
device_to_buffers_size[device]["max_dq_buffer_size"] = max(device_to_buffers_size[device]["max_dq_buffer_size"], submodule.qweight.numel() * 8)
|
||||
|
||||
if use_act_order:
|
||||
device_to_buffers_size[device]["max_inner_outer_dim"] = max(device_to_buffers_size[device]["max_inner_outer_dim"], submodule.infeatures, submodule.outfeatures)
|
||||
|
||||
if model_uses_exllama:
|
||||
# To be honest this is quite ugly, not proud of this.
|
||||
from exllama_kernels import prepare_buffers, set_tuning_params
|
||||
|
||||
device_to_buffers = {}
|
||||
|
||||
if use_act_order:
|
||||
if max_input_length is None:
|
||||
max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
|
||||
else:
|
||||
max_input_len = max_input_length
|
||||
else:
|
||||
if max_input_length is not None:
|
||||
logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")
|
||||
max_input_len = 1
|
||||
|
||||
for device, buffers_size in device_to_buffers_size.items():
|
||||
# The temp_state buffer is required to reorder X in the act-order case.
|
||||
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
device_to_buffers[device] = {
|
||||
"temp_state": torch.zeros((max_input_len, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
|
||||
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
|
||||
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
|
||||
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
|
||||
}
|
||||
|
||||
# Buffers need to be persistent to avoid any bug.
|
||||
model.device_to_buffers = device_to_buffers
|
||||
|
||||
for device, buffers in model.device_to_buffers.items():
|
||||
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
|
||||
|
||||
# Using the default from exllama repo here.
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
# The buffers need to have been initialized first before calling make_q4.
|
||||
for name, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
|
||||
submodule.post_init()
|
||||
|
||||
## exllamav2
|
||||
fixed_bytes = {}
|
||||
model_uses_exllamav2 = False
|
||||
|
||||
for _, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||
model_uses_exllamav2 = True
|
||||
device = submodule.qweight.device
|
||||
scratch_fixed = submodule.scratch_space_fixed()
|
||||
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device,0))
|
||||
|
||||
if model_uses_exllamav2:
|
||||
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
|
||||
device_tensors = {}
|
||||
for device, scratch_bytes in fixed_bytes.items():
|
||||
device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
|
||||
|
||||
# have persistent buffers, otherwise we will get OOM
|
||||
model.device_tensors = device_tensors
|
||||
|
||||
for _, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||
device = submodule.qweight.device
|
||||
submodule.post_init(temp_dq = model.device_tensors[device])
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
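# A hedged usage note for the function above (not part of the diff): when the
# exllama backend is combined with act-order (desc_act=True), the temp_state
# buffer is sized from `max_input_length`, so prompts longer than that must be
# accounted for at post-init time. The length below is an arbitrary example;
# when it is left as None the EXLLAMA_DEFAULT_MAX_INPUT_LENGTH constant is used.
#
#     model = autogptq_post_init(
#         model,
#         use_act_order=quantize_config.desc_act,
#         max_input_length=4096,   # only relevant for exllama + act-order
#     )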
|
||||
|
||||
|
||||
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size, bits: int):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
|
||||
m.register_buffer('bias', torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_device",
|
||||
"move_to_device",
|
||||
"find_layers",
|
||||
"get_module_by_name_prefix",
|
||||
"get_module_by_name_suffix",
|
||||
"get_module_by_name",
|
||||
"make_quant",
|
||||
"preprocess_checkpoint_qigen",
|
||||
"pack_model",
|
||||
"autogptq_post_init",
|
||||
"check_and_get_model_type",
|
||||
"simple_dispatch_model",
|
||||
"make_sure_no_tensor_in_meta_device"
|
||||
"check_and_get_model_type"
|
||||
]
|
||||
|
|
|
@@ -1,23 +1,15 @@
from inspect import signature
from typing import Dict, Optional, Union
from typing import Optional

from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM


GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
@@ -26,17 +18,7 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM,
    "RefinedWeb": RWGPTQForCausalLM,
    "falcon": RWGPTQForCausalLM,
    "baichuan": BaiChuanGPTQForCausalLM,
    "internlm": InternLMGPTQForCausalLM,
    "qwen": QwenGPTQForCausalLM,
    "mistral": MistralGPTQForCausalLM,
    "mpt": MPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM
}
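# Dispatch sketch (not part of the diff): `AutoGPTQForCausalLM` below resolves the
# concrete class through this map via `check_and_get_model_type`, which reads
# `config.model_type` from the checkpoint's config. The model name is an example.
#
#     model_type = check_and_get_model_type("facebook/opt-125m")   # -> "opt"
#     model_cls = GPTQ_CAUSAL_LM_MODEL_MAP[model_type]             # -> OPTGPTQForCausalLM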
|
||||
|
||||
|
||||
|
@ -54,82 +36,40 @@ class AutoGPTQForCausalLM:
|
|||
pretrained_model_name_or_path: str,
|
||||
quantize_config: BaseQuantizeConfig,
|
||||
max_memory: Optional[dict] = None,
|
||||
trust_remote_code: bool = False,
|
||||
**model_init_kwargs
|
||||
) -> BaseGPTQForCausalLM:
|
||||
model_type = check_and_get_model_type(
|
||||
pretrained_model_name_or_path, trust_remote_code
|
||||
)
|
||||
model_type = check_and_get_model_type(pretrained_model_name_or_path)
|
||||
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
quantize_config=quantize_config,
|
||||
max_memory=max_memory,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**model_init_kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_quantized(
|
||||
cls,
|
||||
model_name_or_path: Optional[str],
|
||||
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
|
||||
max_memory: Optional[dict] = None,
|
||||
device: Optional[Union[str, int]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
save_dir: str,
|
||||
device: str = "cpu",
|
||||
use_safetensors: bool = False,
|
||||
use_triton: bool = False,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True,
|
||||
use_cuda_fp16: bool = True,
|
||||
max_memory: Optional[dict] = None,
|
||||
device_map: Optional[str] = None,
|
||||
quantize_config: Optional[BaseQuantizeConfig] = None,
|
||||
model_basename: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
**kwargs
|
||||
trust_remote_code: bool = False
|
||||
) -> BaseGPTQForCausalLM:
|
||||
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
|
||||
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
||||
# A static list of kwargs needed for huggingface_hub
|
||||
huggingface_kwargs = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"local_files_only",
|
||||
"use_auth_token",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"_raise_exceptions_for_missing_entries",
|
||||
"_commit_hash"
|
||||
]
|
||||
# TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
|
||||
keywords = {
|
||||
key: kwargs[key]
|
||||
for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
|
||||
if key in kwargs
|
||||
}
|
||||
return quant_func(
|
||||
model_name_or_path=model_name_or_path,
|
||||
device_map=device_map,
|
||||
max_memory=max_memory,
|
||||
model_type = check_and_get_model_type(save_dir)
|
||||
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
|
||||
save_dir=save_dir,
|
||||
device=device,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
use_safetensors=use_safetensors,
|
||||
use_triton=use_triton,
|
||||
inject_fused_attention=inject_fused_attention,
|
||||
inject_fused_mlp=inject_fused_mlp,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
max_memory=max_memory,
|
||||
device_map=device_map,
|
||||
quantize_config=quantize_config,
|
||||
model_basename=model_basename,
|
||||
use_safetensors=use_safetensors,
|
||||
trust_remote_code=trust_remote_code,
|
||||
warmup_triton=warmup_triton,
|
||||
trainable=trainable,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
**keywords
|
||||
trust_remote_code=trust_remote_code
|
||||
)
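# End-to-end sketch (not part of the diff), written against the `main`-side
# signature shown above; the v0.1.0 side takes `save_dir`/`device` instead.
# The repository name is hypothetical.
#
#     from auto_gptq import AutoGPTQForCausalLM
#
#     model = AutoGPTQForCausalLM.from_quantized(
#         "my-org/opt-125m-4bit-gptq",
#         device="cuda:0",
#         use_safetensors=True,
#         use_triton=False,
#         inject_fused_attention=True,
#     )
#     output_ids = model.generate(**inputs)   # `inputs` as produced by a tokenizer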
|
||||
|
||||
|
||||
|
|
|
@@ -1,16 +0,0 @@
from ._base import *


class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "DecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.W_pack"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"]
    ]


__all__ = ["BaiChuanGPTQForCausalLM"]

@@ -1,16 +0,0 @@
from ._base import *


class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "CodeGenBlock"
    layers_block_name = "transformer.h"
    outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
    inside_layer_modules = [
        ["attn.qkv_proj"],
        ["attn.out_proj"],
        ["mlp.fc_in"],
        ["mlp.fc_out"]
    ]


__all__ = ["CodeGenGPTQForCausalLM"]

@@ -1,17 +0,0 @@
from auto_gptq.modeling import BaseGPTQForCausalLM


class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "GPTBigCodeBlock"
    layers_block_name = "transformer.h"
    outside_layer_modules = [
        "transformer.wpe", "transformer.wte", "transformer.ln_f"
    ]
    inside_layer_modules = [
        ["attn.c_attn"],
        ["attn.c_proj"],
        ["mlp.c_fc"],
        ["mlp.c_proj"]
    ]

__all__ = ["GPTBigCodeGPTQForCausalLM"]

@@ -1,5 +1,4 @@
from ._base import *
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel


class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -13,7 +12,5 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.fc_out"]
    ]

    fused_attn_module_type = FusedGPTJAttentionForQuantizedModel


__all__ = ["GPTJGPTQForCausalLM"]

@@ -1,16 +0,0 @@
from ._base import *


class InternLMGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "InternLMDecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["InternLMGPTQForCausalLM"]

@@ -1,16 +1,4 @@
from logging import getLogger

from ._base import *
from ..utils.import_utils import compare_transformers_version

if compare_transformers_version("v4.28.0", op="ge"):
    from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
    from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
    FusedLlamaAttentionForQuantizedModel = None
    FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):

@@ -24,8 +12,5 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
        ["mlp.down_proj"]
    ]

    fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
    fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["LlamaGPTQForCausalLM"]

@@ -1,16 +0,0 @@
from ._base import *


class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "MistralDecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["MistralGPTQForCausalLM"]

@@ -1,18 +0,0 @@
from auto_gptq.modeling import BaseGPTQForCausalLM


class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "MPTBlock"
    layers_block_name = "transformer.blocks"
    outside_layer_modules = [
        "transformer.wte", "transformer.norm_f"
    ]

    inside_layer_modules = [
        ["attn.Wqkv"],
        ["attn.out_proj"],
        ["ffn.up_proj"],
        ["ffn.down_proj"]
    ]

__all__ = ["MPTGPTQForCausalLM"]

@@ -1,16 +0,0 @@
from ._base import *


class QwenGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "QWenBlock"
    layers_block_name = "transformer.h"
    outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f", "transformer.visual"]
    inside_layer_modules = [
        ["attn.c_attn"],
        ["attn.c_proj"],
        ["mlp.w1", "mlp.w2"],
        ["mlp.c_proj"]
    ]


__all__ = ["QwenGPTQForCausalLM"]

@@ -1,15 +0,0 @@
from ._base import *


class RWGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "DecoderLayer"
    layers_block_name = "transformer.h"
    outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
    inside_layer_modules = [
        ["self_attention.query_key_value"],
        ["self_attention.dense"],
        ["mlp.dense_h_to_4h"],
        ["mlp.dense_4h_to_h"]
    ]

__all__ = ["RWGPTQForCausalLM"]

@@ -1,42 +0,0 @@
from abc import abstractmethod
from logging import getLogger

import torch.nn as nn
from .triton_utils.mixin import TritonModuleMixin


logger = getLogger(__name__)


class FusedBaseModule(nn.Module, TritonModuleMixin):
    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        raise NotImplementedError()


class FusedBaseAttentionModule(FusedBaseModule):
    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        trainable=False,
        **kwargs
    ):
        raise NotImplementedError()

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        pass


class FusedBaseMLPModule(FusedBaseModule):
    @classmethod
    @abstractmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        raise NotImplementedError()
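`FusedBaseAttentionModule` and `FusedBaseMLPModule` only fix the injection contract: a concrete fused module implements a classmethod that walks the model and swaps matching submodules in place, and may override `warmup` to prime kernel caches. A deliberately trivial sketch of that contract, assuming the base class is importable as `auto_gptq.nn_modules._fused_base.FusedBaseMLPModule` (the `IdentityFusedMLP` name and the match rule are hypothetical):

```python
import torch.nn as nn

from auto_gptq.nn_modules._fused_base import FusedBaseMLPModule


class IdentityFusedMLP(FusedBaseMLPModule):
    """Performs no fusion; it only demonstrates the inject_to_model contract."""

    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, x):
        return self.mlp(x)

    @classmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        for name, module in model.named_modules():
            if name.endswith("mlp"):  # illustrative match rule
                parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
                setattr(parent, name.rsplit(".", 1)[-1], cls(module))
```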
@@ -1,303 +0,0 @@
|
|||
from typing import *
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.models.gptj.modeling_gptj import GPTJAttention
|
||||
|
||||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
|
||||
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
||||
dim = x.shape[-1]
|
||||
if seq_len is None:
|
||||
seq_len = x.shape[seq_dim]
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
|
||||
sinusoid_inp = (
|
||||
torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
|
||||
)
|
||||
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
||||
|
||||
|
||||
def rotate_every_two(x):
|
||||
x1 = x[:, :, :, ::2]
|
||||
x2 = x[:, :, :, 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
|
||||
|
||||
|
||||
def duplicate_interleave(m):
|
||||
"""
|
||||
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
|
||||
"""
|
||||
dim0 = m.shape[0]
|
||||
m = m.view(-1, 1) # flatten the matrix
|
||||
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
|
||||
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
|
||||
return m
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x, sincos, offset=0):
|
||||
sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
|
||||
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
|
||||
return (x * cos) + (rotate_every_two(x) * sin)
|
||||
|
||||
|
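The four helpers above implement GPT-J's interleaved ("rotate every two") rotary position embedding. For shape intuition, here is a self-contained restatement of the same math (nothing is imported from the diffed module; the toy sizes are arbitrary):

```python
import torch


def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
    dim = x.shape[-1]
    seq_len = seq_len if seq_len is not None else x.shape[seq_dim]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    sinusoid = torch.einsum("i,j->ij", torch.arange(seq_len, dtype=torch.float), inv_freq)
    return torch.sin(sinusoid), torch.cos(sinusoid)


def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def duplicate_interleave(m):
    dim0 = m.shape[0]
    return m.view(-1, 1).repeat(1, 2).view(dim0, -1)


def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = (duplicate_interleave(t)[None, offset:x.shape[1] + offset, None, :] for t in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)


# (batch, seq, heads, rotary_dim), matching the k_rot / q_rot slices in forward()
k_rot = torch.randn(1, 8, 4, 16)
sincos = fixed_pos_embedding(k_rot, seq_dim=1, seq_len=8)
assert apply_rotary_pos_emb(k_rot, sincos).shape == k_rot.shape
```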
||||
class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.attn_dropout_p = config.attn_pdrop
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_attention_heads
|
||||
if self.head_dim * self.num_attention_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
|
||||
f" `num_attention_heads`: {self.num_attention_heads})."
|
||||
)
|
||||
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
|
||||
|
||||
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.rotary_dim = config.rotary_dim
|
||||
|
||||
def _split_heads(self, qkv):
|
||||
"""
|
||||
Splits hidden dim into attn_head_size and num_attention_heads
|
||||
"""
|
||||
new_shape = qkv.size()[:-1] + (3, self.num_attention_heads, self.head_dim)
|
||||
qkv = qkv.view(new_shape) # (batch, seq_length, 3, head, head_features)
|
||||
query = qkv[:, :, 0]
|
||||
key = qkv[:, :, 1]
|
||||
value = qkv[:, :, 2]
|
||||
|
||||
return query, key, value
|
||||
|
||||
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
|
||||
"""
|
||||
Merges attn_head_size dim and num_attn_heads dim into hidden dim
|
||||
"""
|
||||
if len(tensor.shape) == 5:
|
||||
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
|
||||
elif len(tensor.shape) == 4:
|
||||
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
||||
else:
|
||||
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
|
||||
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
|
||||
return tensor.view(new_shape)
|
||||
|
||||
def _attn(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
):
|
||||
# compute causal mask from causal mask buffer
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
|
||||
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
query = query.to(torch.float32)
|
||||
key = key.to(torch.float32)
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
attn_weights = attn_weights / self.scale_attn
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = attn_weights.to(value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
||||
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
||||
]:
|
||||
query, key, value = self._split_heads(self.qkv_proj(hidden_states))
|
||||
|
||||
seq_len = key.shape[1]
|
||||
offset = 0
|
||||
|
||||
if layer_past is not None:
|
||||
offset = layer_past[0].shape[-2]
|
||||
seq_len += offset
|
||||
|
||||
if self.rotary_dim is not None:
|
||||
k_rot = key[:, :, :, : self.rotary_dim]
|
||||
k_pass = key[:, :, :, self.rotary_dim:]
|
||||
|
||||
q_rot = query[:, :, :, : self.rotary_dim]
|
||||
q_pass = query[:, :, :, self.rotary_dim:]
|
||||
|
||||
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
|
||||
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
|
||||
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
|
||||
|
||||
key = torch.cat([k_rot, k_pass], dim=-1)
|
||||
query = torch.cat([q_rot, q_pass], dim=-1)
|
||||
else:
|
||||
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
|
||||
key = apply_rotary_pos_emb(key, sincos, offset=offset)
|
||||
query = apply_rotary_pos_emb(query, sincos, offset=offset)
|
||||
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
is_causal = layer_past is None
|
||||
if layer_past is not None:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
present = (key, value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
# compute self-attention: V x Softmax(QK^T)
|
||||
if compare_pytorch_version("v2.0.0", op="ge"):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None if is_causal else attention_mask,
|
||||
dropout_p=self.attn_dropout_p,
|
||||
is_causal=is_causal
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
|
||||
@classmethod
|
||||
def inject_to_model(
|
||||
cls,
|
||||
model,
|
||||
use_triton=False,
|
||||
group_size=-1,
|
||||
use_cuda_fp16=True,
|
||||
desc_act=False,
|
||||
trainable=False,
|
||||
bits: int = 4,
|
||||
disable_exllama=True,
|
||||
disable_exllamav2=False,
|
||||
**kwargs
|
||||
):
|
||||
config = model.config
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, GPTJAttention):
|
||||
continue
|
||||
|
||||
attn = cls(config).to(device=next(m.buffers()).device)
|
||||
|
||||
q_proj = m.q_proj
|
||||
k_proj = m.k_proj
|
||||
v_proj = m.v_proj
|
||||
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
|
||||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama":
|
||||
if desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
else:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
|
||||
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
|
||||
|
||||
qlinear_args = (
|
||||
q_proj.bits,
|
||||
q_proj.group_size,
|
||||
q_proj.infeatures,
|
||||
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
|
||||
True if q_proj.bias is not None else False,
|
||||
)
|
||||
qlinear_kwargs = {"trainable": trainable}
|
||||
if (not desc_act or group_size == -1) and not use_triton:
|
||||
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
|
||||
qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
|
||||
qkv_proj.qweight = qweights
|
||||
qkv_proj.qzeros = qzeros
|
||||
qkv_proj.scales = scales
|
||||
qkv_proj.g_idx = g_idx
|
||||
qkv_proj.bias = bias
|
||||
|
||||
if '.' in name:
|
||||
parent_name = name.rsplit('.', 1)[0]
|
||||
child_name = name[len(parent_name) + 1:]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ''
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
attn.qkv_proj = qkv_proj
|
||||
attn.out_proj = m.out_proj
|
||||
|
||||
setattr(parent, child_name, attn)
|
||||
del m
|
||||
|
||||
|
||||
__all__ = ["FusedGPTJAttentionForQuantizedModel"]
|
|
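`inject_to_model` is normally driven internally (for example via the `inject_fused_attention` flag mentioned in the error message above), but it can also be called directly once a quantized GPT-J is in memory. A hedged usage sketch: the import path is assumed to mirror the `..nn_modules.fused_llama_attn` import shown earlier, and `quantized_model` is assumed to be an AutoGPTQ-loaded GPT-J whose attention projections are already QuantLinear layers.

```python
from auto_gptq.nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel

# Kernel and quantization arguments must match the config the checkpoint was packed with.
FusedGPTJAttentionForQuantizedModel.inject_to_model(
    quantized_model.model,   # the wrapped transformers GPT-J model (assumed attribute)
    use_triton=False,
    group_size=128,
    use_cuda_fp16=True,
    desc_act=False,
    bits=4,
    disable_exllama=True,
)
```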
@@ -1,203 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
|
||||
class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
rotary_emb,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
|
||||
if self.head_dim * num_heads != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.qkv_proj = qkv_proj
|
||||
self.o_proj = o_proj
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
def _shape(self, tensor, seq_len, bsz):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_value=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
**kwargs
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
is_causal = past_key_value is None
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
if use_cache:
|
||||
# Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
|
||||
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if compare_pytorch_version("v2.0.0", op="ge"):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=None if is_causal else attention_mask,
|
||||
is_causal=is_causal
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
@classmethod
|
||||
def inject_to_model(
|
||||
cls,
|
||||
model,
|
||||
use_triton=False,
|
||||
group_size=-1,
|
||||
use_cuda_fp16=True,
|
||||
desc_act=False,
|
||||
trainable=False,
|
||||
bits: int = 4,
|
||||
disable_exllama=True,
|
||||
disable_exllamav2=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||
"""
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, LlamaAttention):
|
||||
continue
|
||||
|
||||
q_proj = m.q_proj
|
||||
k_proj = m.k_proj
|
||||
v_proj = m.v_proj
|
||||
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
|
||||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama":
|
||||
if desc_act:
|
||||
# TODO: support it. The issue may lie in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
else:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
|
||||
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
|
||||
|
||||
qlinear_args = (
|
||||
q_proj.bits,
|
||||
q_proj.group_size,
|
||||
q_proj.infeatures,
|
||||
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
|
||||
True if q_proj.bias is not None else False,
|
||||
)
|
||||
qlinear_kwargs = {"trainable": trainable}
|
||||
if (not desc_act or group_size == -1) and not use_triton:
|
||||
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
|
||||
qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
|
||||
qkv_layer.qweight = qweights
|
||||
qkv_layer.qzeros = qzeros
|
||||
qkv_layer.scales = scales
|
||||
qkv_layer.g_idx = g_idx
|
||||
qkv_layer.bias = bias
|
||||
|
||||
attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
|
||||
|
||||
if '.' in name:
|
||||
parent_name = name.rsplit('.', 1)[0]
|
||||
child_name = name[len(parent_name) + 1:]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ''
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, attn)
|
||||
|
||||
|
||||
__all__ = ["FusedLlamaAttentionForQuantizedModel"]
|
|
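Both fused-attention `inject_to_model` implementations end with the same idiom for splicing the new module into the model: split the dotted name into a parent path and a child attribute, then `setattr` on the parent. A self-contained sketch of that idiom (the `replace_submodule` helper name is mine, not part of the library):

```python
import torch.nn as nn


def replace_submodule(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Replace model.<name> with new_module, mirroring the idiom used above."""
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)


# Usage: swap the inner Linear of the second block for an Identity (illustration only).
model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
replace_submodule(model, "1.0", nn.Identity())
```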
@@ -1,330 +0,0 @@
|
|||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||
|
||||
from ._fused_base import FusedBaseMLPModule
|
||||
from ..utils.import_utils import TRITON_AVAILABLE
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
if TRITON_AVAILABLE:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .triton_utils import custom_autotune
|
||||
from .triton_utils.kernels import silu
|
||||
|
||||
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 256,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
), # 3090
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 16,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
), # 3090
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4
|
||||
), # 3090
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 16,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
), # 3090
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
), # 3090
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def quant_fused_matmul_248_kernel(
|
||||
a_ptr, c_ptr, b1_ptr,
|
||||
scales1_ptr, zeros1_ptr,
|
||||
g1_ptr, b2_ptr,
|
||||
scales2_ptr, zeros2_ptr,
|
||||
g2_ptr,
|
||||
M, N, K,
|
||||
bits, maxq,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
"""
|
||||
Computes: C = silu(A * B1) * (A * B2)
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (1, N) float16
|
||||
zeros is of shape (1, N//8) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
g1_ptrs = g1_ptr + offs_k
|
||||
g2_ptrs = g2_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales1_ptrs = scales1_ptr + offs_bn[None, :]
|
||||
scales2_ptrs = scales2_ptr + offs_bn[None, :]
|
||||
zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, num_pid_k):
|
||||
g1_idx = tl.load(g1_ptrs)
|
||||
g2_idx = tl.load(g2_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
|
||||
|
||||
zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
|
||||
zeros1 = (zeros1 + 1)
|
||||
|
||||
zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
|
||||
zeros2 = (zeros2 + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
b2 = tl.load(b2_ptrs)
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b1 = (b1 - zeros1) * scales1 # Scale and shift
|
||||
accumulator1 += tl.dot(a, b1)
|
||||
|
||||
b2 = (b2 >> shifter[:, None]) & maxq
|
||||
b2 = (b2 - zeros2) * scales2
|
||||
accumulator2 += tl.dot(a, b2)
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g1_ptrs += BLOCK_SIZE_K
|
||||
g2_ptrs += BLOCK_SIZE_K
|
||||
|
||||
accumulator1 = silu(accumulator1)
|
||||
c = accumulator1 * accumulator2
|
||||
c = c.to(tl.float16)
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
else:
|
||||
quant_fused_matmul_248_kernel = None
|
||||
|
||||
|
||||
class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
|
||||
def __init__(
|
||||
self,
|
||||
gate_proj,
|
||||
down_proj,
|
||||
up_proj,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.infeatures = gate_proj.infeatures
|
||||
self.intermediate_size = gate_proj.outfeatures
|
||||
self.outfeatures = down_proj.outfeatures
|
||||
self.bits = gate_proj.bits
|
||||
self.maxq = gate_proj.maxq
|
||||
|
||||
self.gate_proj = gate_proj
|
||||
self.up_proj = up_proj
|
||||
self.down_proj = down_proj
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.triton_llama_mlp(x))
|
||||
|
||||
def triton_llama_mlp(self, x):
|
||||
with torch.cuda.device(x.device):
|
||||
out_shape = x.shape[:-1] + (self.intermediate_size, )
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
M, K = x.shape
|
||||
N = self.intermediate_size
|
||||
c = torch.empty((M, N), device=x.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
quant_fused_matmul_248_kernel[grid](
|
||||
x, c, self.gate_proj.qweight,
|
||||
self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,
|
||||
self.up_proj.qweight,
|
||||
self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,
|
||||
M, N, K,
|
||||
self.bits, self.maxq,
|
||||
x.stride(0), x.stride(1),
|
||||
self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)
|
||||
)
|
||||
c = c.reshape(out_shape)
|
||||
return c
|
||||
|
||||
@classmethod
|
||||
def inject_to_model(cls, model, use_triton=False, **kwargs):
|
||||
if not use_triton:
|
||||
logger.warning(f"Skipping module injection for {cls.__name__}: injection is not supported without triton yet.")
|
||||
return
|
||||
elif not TRITON_AVAILABLE:
|
||||
logger.warning("Skipping module injection because triton is not installed.")
|
||||
return
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, LlamaMLP):
|
||||
continue
|
||||
|
||||
mlp = cls(m.gate_proj, m.down_proj, m.up_proj)
|
||||
|
||||
if '.' in name:
|
||||
parent_name = name.rsplit('.', 1)[0]
|
||||
child_name = name[len(parent_name) + 1:]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ''
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, mlp)
|
||||
|
||||
@classmethod
|
||||
def warmup(cls, model, transpose=False, seqlen=2048):
|
||||
from tqdm import tqdm
|
||||
|
||||
kn_values = {}
|
||||
|
||||
for _, m in model.named_modules():
|
||||
if not isinstance(m, cls):
|
||||
continue
|
||||
|
||||
k = m.infeatures
|
||||
n = m.intermediate_size
|
||||
|
||||
if (k, n) not in kn_values:
|
||||
kn_values[(k, n)] = m
|
||||
|
||||
logger.info(f'Found {len(kn_values)} unique fused mlp KN values.')
|
||||
logger.info('Warming up autotune cache ...')
|
||||
with torch.no_grad():
|
||||
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
|
||||
m = 2 ** m
|
||||
for (k, n), (modules) in kn_values.items():
|
||||
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
|
||||
modules.triton_llama_mlp(a)
|
||||
del kn_values
|
||||
|
||||
|
||||
__all__ = ["FusedLlamaMLPForQuantizedModel"]
|
|
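The fused kernel above is launched on a 1-D grid of `cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)` programs and re-derives a 2-D tile index from the flat program id, using the same grouped ordering as the standard Triton matmul recipe to improve L2 reuse. A pure-Python sketch of that index math (no Triton needed), with block sizes chosen only for illustration:

```python
import math


def tile_of_pid(pid, M, N, BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, GROUP_SIZE_M=8):
    """Mirror the pid -> (pid_m, pid_n) mapping used at the top of the kernel."""
    num_pid_m = math.ceil(M / BLOCK_SIZE_M)
    num_pid_n = math.ceil(N / BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# Every (pid_m, pid_n) output tile of a 1024x256 problem is visited exactly once.
M, N = 1024, 256
n_programs = math.ceil(M / 64) * math.ceil(N / 32)
assert len({tile_of_pid(p, M, N) for p in range(n_programs)}) == n_programs
```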
@@ -9,40 +9,32 @@ import transformers
|
|||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
import autogptq_cuda_256
|
||||
import autogptq_cuda_64
|
||||
_autogptq_cuda_available = True
|
||||
import quant_cuda
|
||||
|
||||
_quant_cuda_available = True
|
||||
except ImportError:
|
||||
logger.warning('CUDA extension not installed.')
|
||||
autogptq_cuda_256 = None
|
||||
autogptq_cuda_64 = None
|
||||
_autogptq_cuda_available = False
|
||||
_quant_cuda_available = False
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "cuda"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bits,
|
||||
group_size,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
kernel_switch_threshold=128,
|
||||
trainable=False
|
||||
):
|
||||
super().__init__()
|
||||
global _autogptq_cuda_available
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
if trainable:
|
||||
_autogptq_cuda_available = False
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
self.register_buffer(
|
||||
|
@@ -51,15 +43,15 @@ class QuantLinear(nn.Module):
|
|||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
|
@@ -80,18 +72,9 @@ class QuantLinear(nn.Module):
|
|||
).reshape(1, 3, 12)
|
||||
|
||||
self.kernel_switch_threshold = kernel_switch_threshold
|
||||
self.autogptq_cuda_available = _autogptq_cuda_available
|
||||
|
||||
self.autogptq_cuda = autogptq_cuda_256
|
||||
self.quant_cuda_available = _quant_cuda_available
|
||||
if infeatures % 256 != 0 or outfeatures % 256 != 0:
|
||||
self.autogptq_cuda = autogptq_cuda_64
|
||||
if infeatures % 64 != 0 or outfeatures % 64 != 0:
|
||||
self.autogptq_cuda_available = False
|
||||
|
||||
self.trainable = trainable
|
||||
|
||||
def post_init(self):
|
||||
pass
|
||||
self.quant_cuda_available = False
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
W = linear.weight.data.clone()
|
||||
|
@@ -196,20 +179,21 @@ class QuantLinear(nn.Module):
|
|||
def forward(self, x: torch.Tensor):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if self.autogptq_cuda_available and (
|
||||
if self.quant_cuda_available and (
|
||||
self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
|
||||
):
|
||||
out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
|
||||
if self.bits == 2:
|
||||
self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
elif self.bits == 3:
|
||||
self.autogptq_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
elif self.bits == 4:
|
||||
self.autogptq_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
elif self.bits == 8:
|
||||
self.autogptq_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
out = out.half()
|
||||
else:
|
||||
if self.wf.device != self.qzeros.device:
|
||||
self.wf = self.wf.to(self.qzeros.device)
|
||||
|
@@ -219,7 +203,7 @@ class QuantLinear(nn.Module):
|
|||
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||
self.wf.unsqueeze(0)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(self.scales.shape)
|
||||
|
@@ -228,7 +212,7 @@ class QuantLinear(nn.Module):
|
|||
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||
self.wf.unsqueeze(-1)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
|
||||
torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
|
||||
elif self.bits == 3:
|
||||
zeros = self.qzeros.reshape(
|
||||
self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
|
||||
|
@@ -254,23 +238,12 @@ class QuantLinear(nn.Module):
|
|||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
num_itr = self.g_idx.shape[0]//x.shape[-1]
|
||||
if num_itr == 1:
|
||||
|
||||
weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]))
|
||||
else:
|
||||
num_dim = self.g_idx.shape[0]//num_itr
|
||||
weights = []
|
||||
for i in range(num_itr):
|
||||
scale_i = self.scales[:,i*num_dim:(i+1)*num_dim]
|
||||
weight_i = weight[:,i*num_dim:(i+1)*num_dim]
|
||||
zeros_i = zeros[:,i*num_dim:(i+1)*num_dim]
|
||||
g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
|
||||
weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
|
||||
weights = torch.cat(weights,dim=1)
|
||||
out = torch.matmul(x.to(weights.dtype), weights)
|
||||
out = out.half().reshape(out_shape)
|
||||
out = torch.matmul(x.half(), weights)
|
||||
out = out.reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["QuantLinear"]
|
|
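In the non-kernel branch above, dequantization happens on the fly: each packed int32 is right-shifted by the offsets stored in `wf` (0, 4, ..., 28 for 4-bit) and masked down to `2**bits - 1`, after which zeros and scales are applied per group. A self-contained sketch of just the unpacking step for `bits=4` (independent of the class itself):

```python
import torch

bits = 4
wf = torch.arange(0, 32, bits, dtype=torch.int32)  # shift amounts 0, 4, ..., 28

# Pack eight 4-bit values into a single int32, low nibble first, as pack() does.
values = torch.tensor([1, 7, 0, 15, 3, 9, 12, 5], dtype=torch.int32)
packed = torch.zeros(1, dtype=torch.int32)
for j, v in enumerate(values):
    packed |= v << (bits * j)

# Unpack: broadcast-shift by wf, then mask to the low `bits` bits.
unpacked = torch.bitwise_and(packed.unsqueeze(-1) >> wf, (1 << bits) - 1)
assert torch.equal(unpacked.flatten(), values)
```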
@@ -1,56 +0,0 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class GeneralQuantLinear(nn.Linear):
|
||||
def __init__(self, quant_linear_module):
|
||||
super().__init__(
|
||||
in_features=quant_linear_module.infeatures,
|
||||
out_features=quant_linear_module.outfeatures,
|
||||
bias=True
|
||||
)
|
||||
self.infeatures = quant_linear_module.infeatures
|
||||
self.outfeatures = quant_linear_module.outfeatures
|
||||
self.bits = quant_linear_module.bits
|
||||
self.group_size = quant_linear_module.group_size
|
||||
self.maxq = quant_linear_module.maxq
|
||||
|
||||
self.weight.requires_grad = False
|
||||
|
||||
self.weight.data = quant_linear_module.qweight
|
||||
self.register_buffer('qweight', quant_linear_module.qweight)
|
||||
self.bias.data = quant_linear_module.bias
|
||||
|
||||
self.qweight.requires_grad = False
|
||||
self.bias.requires_grad = False
|
||||
|
||||
self.register_buffer('qzeros', quant_linear_module.qzeros)
|
||||
self.register_buffer('scales', quant_linear_module.scales)
|
||||
self.register_buffer('g_idx', quant_linear_module.g_idx)
|
||||
|
||||
if hasattr(quant_linear_module, "wf"):
|
||||
self.wf = quant_linear_module.wf
|
||||
if hasattr(quant_linear_module, "kernel_switch_threshold"):
|
||||
self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
|
||||
if hasattr(quant_linear_module, "autogptq_cuda_available"):
|
||||
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
|
||||
|
||||
self.trainable = quant_linear_module.trainable
|
||||
|
||||
self.forward = quant_linear_module.forward
|
||||
|
||||
@classmethod
|
||||
def inject_to_model(cls, model, target_module_type):
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, target_module_type):
|
||||
continue
|
||||
new_m = cls(m)
|
||||
if '.' in name:
|
||||
parent_name = name.rsplit('.', 1)[0]
|
||||
child_name = name[len(parent_name) + 1:]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ''
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, new_m)
|
|
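`GeneralQuantLinear` re-exposes a packed `QuantLinear` through the `nn.Linear` interface (weight and bias attributes plus the registered quantization buffers), which helps tooling that type-checks for `nn.Linear`. A hedged usage sketch; the variable names are illustrative, not taken from this diff:

```python
# `quantized_model` is assumed to already contain layers of type `QuantLinearImpl`
# (whichever cuda / triton / exllama QuantLinear class the checkpoint was loaded with).
GeneralQuantLinear.inject_to_model(quantized_model, target_module_type=QuantLinearImpl)
```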
@@ -1,171 +0,0 @@
|
|||
# Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import numpy as np
|
||||
import transformers
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from exllama_kernels import make_q4, q4_matmul
|
||||
except ImportError:
|
||||
logger.error('exllama_kernels not installed.')
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx if g_idx is not None else none_tensor,
|
||||
device)
|
||||
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "exllama"
|
||||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||
super().__init__()
|
||||
if bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllama kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
||||
if trainable:
|
||||
raise NotImplementedError("Exllama kernel does not support training.")
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.trainable = trainable
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
assert infeatures % 32 == 0
|
||||
assert infeatures % self.group_size == 0
|
||||
assert outfeatures % 32 == 0
|
||||
|
||||
self.register_buffer(
|
||||
'qweight',
|
||||
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def post_init(self):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
assert self.qweight.device.index is not None
|
||||
|
||||
self.width = self.qweight.shape[1]
|
||||
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx.to("cpu") if self._use_act_order else None,
|
||||
self.qweight.device.index
|
||||
)
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
W = linear.weight.data.clone()
|
||||
if isinstance(linear, nn.Conv2d):
|
||||
W = W.flatten(1)
|
||||
if isinstance(linear, transformers.pytorch_utils.Conv1D):
|
||||
W = W.t()
|
||||
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(
|
||||
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
|
||||
i = 0
|
||||
row = 0
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x.half(), self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
|
@@ -1,188 +0,0 @@
|
|||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||
except ImportError:
|
||||
logger.error('exllamav2_kernels not installed.')
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
def _torch_device(idx):
|
||||
if idx == -1: return "cpu"
|
||||
return f"cuda:{idx}"
|
||||
|
||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
output_shape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
|
||||
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||
return output.view(output_shape)
|
||||
|
||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
"""
|
||||
Create Q matrix
|
||||
"""
|
||||
# EXL2
|
||||
# won't work at the moment because the tensors are not the same.
|
||||
if "q_weight" in w:
|
||||
w["q_scale_max"] /= 256
|
||||
w["q_perm"] = w["q_perm"].short()
|
||||
w["q_invperm"] = w["q_invperm"].short()
|
||||
return make_q_matrix(w["q_weight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
w["q_scale"],
|
||||
w["q_scale_max"],
|
||||
w["q_groups"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
# GPTQ
|
||||
elif "qweight" in w:
|
||||
if w["scales"].dtype == torch.float:
|
||||
w["scales"] = w["scales"].half()
|
||||
|
||||
# GPTQ with g_idx (act_order)
|
||||
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
|
||||
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
|
||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
return make_q_matrix(w["qweight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w["g_idx"].cpu(),
|
||||
temp_dq)
|
||||
# GPTQ without g_idx
|
||||
else:
|
||||
return make_q_matrix(w["qweight"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "exllamav2"
|
||||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||
super().__init__()
|
||||
if bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
||||
if trainable:
|
||||
raise NotImplementedError("Exllamav2 kernel does not support training.")
|
||||
|
||||
self.q_handle = None
|
||||
self.q_tensors = None
|
||||
self.padding = - outfeatures % 32
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures + self.padding
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.trainable = trainable
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
assert infeatures % 32 == 0
|
||||
assert infeatures % self.group_size == 0
|
||||
assert outfeatures % 32 == 0
|
||||
|
||||
# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
|
||||
self.register_buffer(
|
||||
'qweight',
|
||||
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def post_init(self, temp_dq):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
assert self.qweight.device.index is not None
|
||||
self.q_tensors = {
|
||||
"qweight":self.qweight,
|
||||
"qzeros":self.qzeros,
|
||||
"scales":self.scales,
|
||||
"g_idx":self.g_idx
|
||||
}
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
self.q_handle = ext_make_q_matrix(
|
||||
self.q_tensors, temp_dq
|
||||
)
|
||||
|
||||
def forward(self, x, force_cuda = False):
|
||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||
|
||||
if self.bias is not None:
|
||||
output.add_(self.bias)
|
||||
return output
|
||||
|
||||
def temp_dq_size(self):
|
||||
return self.infeatures * self.outfeatures * 2 + 128
|
||||
|
||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||
|
||||
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
|
||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||
|
||||
|
||||
class ExLlamaV2DeviceTensors:
|
||||
|
||||
device_idx: int
|
||||
scratch_bytes: int
|
||||
scratch_idx: int
|
||||
scratch: torch.tensor = None
|
||||
|
||||
def __init__(self, device_idx, scratch_bytes):
|
||||
self.device_idx = device_idx
|
||||
self.scratch_bytes = scratch_bytes
|
||||
|
||||
def prepare(self):
|
||||
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
|
||||
|
||||
def get_scratch_slice(self, size_bytes):
|
||||
|
||||
if self.scratch is None: self.prepare()
|
||||
|
||||
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||
size_half = size_bytes // 2
|
||||
scratch_slice = self.scratch.narrow(0, 0, size_half)
|
||||
return scratch_slice
|
|
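The scratch sizing above is plain arithmetic: the dequantization buffer needs `infeatures * outfeatures` fp16 elements (2 bytes each) plus a small alignment pad, and the forward buffer scales with output width, maximum input length and batch size. A worked example under assumed Llama-7B-like dimensions (the shapes are illustrative only):

```python
infeatures, outfeatures = 4096, 11008            # roughly a 4096 x 11008 MLP projection
max_input_len, max_batch_size = 2048, 8

temp_dq_size = infeatures * outfeatures * 2 + 128                     # fp16 dequant buffer, bytes
temp_fwd_size = outfeatures * max_input_len * max_batch_size * 4 + 128
scratch_space_fixed = temp_dq_size + temp_fwd_size

print(f"dequant {temp_dq_size / 2**20:.1f} MiB + forward {temp_fwd_size / 2**20:.1f} MiB "
      f"= {scratch_space_fixed / 2**20:.1f} MiB scratch for this layer shape")
```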
@@ -1,262 +0,0 @@
|
|||
from copy import deepcopy
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
import gc
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from gekko import GEKKO
|
||||
from logging import getLogger
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
import cQIGen as qinfer
|
||||
except ImportError:
|
||||
logger.error('cQIGen not installed.')
|
||||
raise
|
||||
|
||||
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
|
||||
m = GEKKO() # create GEKKO model
|
||||
#cinfergen if bits==3:
|
||||
# tu = tu*3
|
||||
B = m.Const(value=bits)
|
||||
TP = m.Const(value=T//p)
|
||||
k = m.Var(1,integer=True,lb=1)
|
||||
z = m.Var(1,integer=True,lb=1)
|
||||
w = m.Var(1,integer=True,lb=1)
|
||||
y = m.Var(1,integer=True,lb=1)
|
||||
v = m.Var(1,integer=True,lb=1)
|
||||
mb = m.Var(mu,integer=True,lb=1)
|
||||
if gs != -1:
|
||||
gg = m.Var(1,integer=True,lb=1)
|
||||
tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
|
||||
L = m.Var(integer=True,lb=0,ub=l1)
|
||||
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
|
||||
m.Equation(mb * k == M)
|
||||
if gs != -1:
|
||||
m.Equation(gs * gg == mb)
|
||||
# m.Equation(tb * z == T)
|
||||
m.Equation(tb * z == TP)
|
||||
m.Equation(mu * w == mb)
|
||||
m.Equation(tu * y == tb)
|
||||
# m.Equation(tb * v == tt)
|
||||
m.Maximize(L)
|
||||
m.options.SOLVER = 1
|
||||
m.solver_options = ['minlp_maximum_iterations 1000', \
|
||||
# minlp iterations with integer solution
|
||||
'minlp_max_iter_with_int_sol 10', \
|
||||
# treat minlp as nlp
|
||||
'minlp_as_nlp 0', \
|
||||
# nlp sub-problem max iterations
|
||||
'nlp_maximum_iterations 100', \
|
||||
# 1 = depth first, 2 = breadth first
|
||||
'minlp_branch_method 2', \
|
||||
# maximum deviation from whole number
|
||||
'minlp_integer_tol 0.00', \
|
||||
# convergence tolerance
|
||||
'minlp_gap_tol 0.01']
|
||||
try:
|
||||
m.solve(disp=False)
|
||||
except:
|
||||
try:
|
||||
m.solver_options = ['minlp_maximum_iterations 1000', \
|
||||
# minlp iterations with integer solution
|
||||
'minlp_max_iter_with_int_sol 10', \
|
||||
# treat minlp as nlp
|
||||
'minlp_as_nlp 0', \
|
||||
# nlp sub-problem max iterations
|
||||
'nlp_maximum_iterations 100', \
|
||||
# 1 = depth first, 2 = breadth first
|
||||
'minlp_branch_method 1', \
|
||||
# maximum deviation from whole number
|
||||
'minlp_integer_tol 0.00', \
|
||||
# convergence tolerance
|
||||
'minlp_gap_tol 0.01']
|
||||
m.solve(disp=False)
|
||||
except:
|
||||
# mytb = T//p
|
||||
mytb = tu
|
||||
if gs != -1:
|
||||
mymb = gs
|
||||
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
|
||||
mymb += gs
|
||||
while M % mymb != 0:
|
||||
mymb -= gs
|
||||
return (int(mymb), int(mytb))
|
||||
else:
|
||||
mymb = mu
|
||||
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
|
||||
mymb += mu
|
||||
while M % mymb != 0:
|
||||
mymb -= mu
|
||||
return (int(mymb), int(mytb))
|
||||
|
||||
return (int(mb.value[0]), int(tb.value[0]))
|
||||
|
||||
params = {}
|
||||
|
||||
def compute_reductions(x, gs=-1, cpp=True):
|
||||
if cpp:
|
||||
if len(x.shape) != 1:
|
||||
rows, cols = x.shape
|
||||
else:
|
||||
rows = 1
|
||||
cols = x.shape[0]
|
||||
if gs == -1:
|
||||
out = torch.zeros(rows).float().contiguous()
|
||||
mygs = cols
|
||||
else:
|
||||
out = torch.zeros(rows, cols // gs).float().contiguous()
|
||||
mygs = gs
|
||||
|
||||
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
|
||||
return out
|
||||
if gs == -1:
|
||||
if len(x.shape) != 1:
|
||||
return torch.sum(x,1)
|
||||
else:
|
||||
return torch.sum(x)
|
||||
else:
|
||||
if len(x.shape) != 1:
|
||||
rows, cols = x.shape
|
||||
out = torch.zeros(rows, cols // gs).float().contiguous()
|
||||
for i in range(cols // gs):
|
||||
out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
|
||||
return out
|
||||
else:
|
||||
cols = x.shape[0]
|
||||
out = torch.zeros(cols // gs).float().contiguous()
|
||||
for i in range(cols // gs):
|
||||
out[i] = torch.sum(x[i*gs:(i+1)*gs])
|
||||
return out
|
||||
|
||||
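`compute_reductions` returns, for every input row, the sum of each consecutive group of `gs` activations; the qigen kernels use these group sums to fold the zero-points back into the matmul. The cpp path is equivalent to this short pure-torch expression (shown only to document the semantics):

```python
import torch

x = torch.randn(4, 128)                              # (rows, cols)
gs = 32                                               # group size
group_sums = x.view(x.shape[0], -1, gs).sum(dim=-1)   # (rows, cols // gs)
assert group_sums.shape == (4, 128 // gs)
```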
def process_zeros_scales(zeros, scales, bits, M):
|
||||
if zeros.dtype != torch.float32:
|
||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
||||
if bits == 4:
|
||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 2:
|
||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 3:
|
||||
logger.info("Unpacking zeros for 3 bits")
|
||||
new_scales = scales.contiguous()
|
||||
else:
|
||||
if scales.shape[1] != M:
|
||||
new_scales = scales.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_scales = scales.contiguous()
|
||||
if zeros.shape[1] != M:
|
||||
new_zeros = zeros.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_zeros = zeros.contiguous()
|
||||
|
||||
return new_zeros, new_scales
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "qigen"
|
||||
|
||||
def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
|
||||
super().__init__()
|
||||
if bits not in [2, 4]:
|
||||
raise NotImplementedError("Only 2,4 bits are supported.")
|
||||
if trainable:
|
||||
raise NotImplementedError("Qigen kernel does not support training.")
|
||||
self.bits = bits
|
||||
pack = 32 // bits
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
|
||||
n = hint
|
||||
m = self.infeatures
|
||||
t = self.outfeatures
|
||||
|
||||
# registers are fixed for now
|
||||
if bits == 3:
|
||||
packed = 32
|
||||
unroll = 3
|
||||
nu = 1 #args.n
|
||||
mu = 32
|
||||
tu = 32
|
||||
else:
|
||||
packed = 32 // bits
|
||||
unroll = 2
|
||||
nu = 1 #args.n
|
||||
mu = 16
|
||||
tu = 32
|
||||
|
||||
nb = n # it's always small for transformers
|
||||
|
||||
global params
|
||||
if (m,t) in params:
|
||||
mb = params[(m,t)][0]
|
||||
tb = params[(m,t)][1]
|
||||
else:
|
||||
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
|
||||
params[(m,t)] = (mb,tb)
|
||||
|
||||
split = np.ones(p)
|
||||
split = split * tb
|
||||
while np.sum(split) < t:
|
||||
split = split + tb
|
||||
|
||||
idx = p - 1
|
||||
while np.sum(split) > t:
|
||||
split[idx] = split[idx] - tb
|
||||
idx = idx - 1
|
||||
|
||||
assert(np.sum(split) == t)
|
||||
|
||||
split = split.astype(int)
|
||||
self.tt = int(split[0])
|
||||
|
||||
if split[0] == split[-1]:
|
||||
self.cutoff = int(p+1)
|
||||
else:
|
||||
self.cutoff = int(idx + 1)
|
||||
|
||||
self.mb = mb #// packed
|
||||
self.tb = tb
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
self.register_buffer('bias', torch.zeros(self.outfeatures))
|
||||
self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
|
||||
if bits == 4:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
|
||||
elif bits == 3:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
|
||||
elif bits == 2:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
|
||||
B = x.shape[0]
|
||||
new_x = x.T.contiguous()
|
||||
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
|
||||
sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
|
||||
if self.group_size == -1:
|
||||
if self.bits == 4:
|
||||
qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
elif self.bits == 2:
|
||||
qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
elif self.bits == 3:
|
||||
qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
else:
|
||||
if self.bits == 4:
|
||||
qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
elif self.bits == 2:
|
||||
qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
elif self.bits == 3:
|
||||
qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
return out.reshape(out_shape)
|
|
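One detail of the constructor above worth isolating is how the outfeature dimension is divided across the p threads: tiles of width tb are grown until they cover t, then trimmed from the last thread downward, and cutoff records where the shorter tail begins. A standalone rerun of that logic with small, invented numbers:
import numpy as np

p, tb, t = 8, 32, 224            # threads, tile width, outfeatures (illustrative)
split = np.ones(p) * tb
while np.sum(split) < t:         # grow every thread's share until t is covered
    split = split + tb
idx = p - 1
while np.sum(split) > t:         # then shave tb off the last threads
    split[idx] -= tb
    idx -= 1
assert np.sum(split) == t
cutoff = int(p + 1 if split[0] == split[-1] else idx + 1)
print(split.astype(int), cutoff)  # [32 32 32 32 32 32 32  0] 7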
@@ -1,186 +0,0 @@
|
|||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
from ..triton_utils.mixin import TritonModuleMixin
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ..triton_utils.kernels import (
|
||||
quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
|
||||
QuantLinearFunction, QuantLinearInferenceOnlyFunction
|
||||
)
|
||||
except ImportError:
|
||||
logger.error('triton not installed.')
|
||||
raise
|
||||
|
||||
|
||||
class QuantLinear(nn.Module, TritonModuleMixin):
|
||||
QUANT_TYPE = "triton"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bits,
|
||||
group_size,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
trainable=False
|
||||
):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
if infeatures % 32 != 0 or outfeatures % 32 != 0:
|
||||
raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
self.register_buffer(
|
||||
'qweight',
|
||||
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.trainable = trainable
|
||||
|
||||
def post_init(self):
|
||||
pass
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
W = linear.weight.data.clone()
|
||||
if isinstance(linear, nn.Conv2d):
|
||||
W = W.flatten(1)
|
||||
if isinstance(linear, transformers.pytorch_utils.Conv1D):
|
||||
W = W.t()
|
||||
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(
|
||||
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
|
||||
i = 0
|
||||
row = 0
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
|
||||
out = quant_linear_fn.apply(
|
||||
x.reshape(-1, x.shape[-1]),
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.g_idx,
|
||||
self.bits,
|
||||
self.maxq
|
||||
)
|
||||
out = out.half().reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def warmup(cls, model, transpose=False, seqlen=2048):
|
||||
"""
|
||||
Pre-tunes the quantized kernel
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
|
||||
kn_values = {}
|
||||
|
||||
for _, m in model.named_modules():
|
||||
if not isinstance(m, cls):
|
||||
continue
|
||||
|
||||
k = m.infeatures
|
||||
n = m.outfeatures
|
||||
|
||||
if (k, n) not in kn_values:
|
||||
kn_values[(k, n)] = (m.qweight, m.scales, m.qzeros, m.g_idx, m.bits, m.maxq)
|
||||
|
||||
logger.info(f'Found {len(kn_values)} unique KN Linear values.')
|
||||
logger.info('Warming up autotune cache ...')
|
||||
with torch.no_grad():
|
||||
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
|
||||
m = 2 ** m
|
||||
for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
|
||||
if transpose:
|
||||
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
|
||||
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
a = torch.randn(m, n, dtype=torch.float16, device=model.device)
|
||||
transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
else:
|
||||
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
|
||||
quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
del kn_values
|
||||
|
||||
|
||||
__all__ = ["QuantLinear"]
|
|
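The pack() method above packs 32 // bits quantized integers into each 32-bit word of qweight (and likewise for qzeros), and the Triton kernel's shifter undoes this. A tiny self-contained illustration of the same 4-bit packing, using invented values:
bits = 4
vals = list(range(8))              # eight 4-bit codes: 0..7
word = 0
for j, v in enumerate(vals):
    word |= v << (bits * j)        # code j lands in bit positions [4*j, 4*j + 4)
assert hex(word) == '0x76543210'
unpacked = [(word >> (bits * j)) & 0xF for j in range(8)]
assert unpacked == vals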
@@ -7,41 +7,34 @@ import torch.nn as nn
|
|||
import transformers
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
import autogptq_cuda_256
|
||||
import autogptq_cuda_64
|
||||
_autogptq_cuda_available = True
|
||||
import quant_cuda
|
||||
|
||||
_quant_cuda_available = True
|
||||
except ImportError:
|
||||
logger.warning('CUDA extension not installed.')
|
||||
autogptq_cuda_256 = None
|
||||
autogptq_cuda_64 = None
|
||||
_autogptq_cuda_available = False
|
||||
_quant_cuda_available = False
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "cuda-old"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bits,
|
||||
group_size,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
use_cuda_fp16=True,
|
||||
kernel_switch_threshold=128,
|
||||
trainable=False
|
||||
faster=True,
|
||||
kernel_switch_threshold=128
|
||||
):
|
||||
super().__init__()
|
||||
global _autogptq_cuda_available
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
if trainable:
|
||||
_autogptq_cuda_available = False
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
self.register_buffer(
|
||||
|
@@ -50,24 +43,22 @@ class QuantLinear(nn.Module):
|
|||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
self.half_indim = self.infeatures // 2
|
||||
|
||||
self.use_cuda_fp16 = use_cuda_fp16 if bits != 8 else False
|
||||
self.faster = faster if bits != 8 else False
|
||||
|
||||
# is performed by unpacking the weights and using torch.matmul
|
||||
if self.bits in [2, 4, 8]:
|
||||
|
@@ -83,25 +74,11 @@ class QuantLinear(nn.Module):
|
|||
).reshape(1, 3, 12)
|
||||
|
||||
self.kernel_switch_threshold = kernel_switch_threshold
|
||||
self.autogptq_cuda_available = _autogptq_cuda_available
|
||||
self.autogptq_cuda = autogptq_cuda_256
|
||||
self.quant_cuda_available = _quant_cuda_available
|
||||
if infeatures % 256 != 0 or outfeatures % 256 != 0:
|
||||
self.autogptq_cuda = autogptq_cuda_64
|
||||
if infeatures % 64 != 0 or outfeatures % 64 != 0:
|
||||
self.autogptq_cuda_available = False
|
||||
|
||||
self.trainable = trainable
|
||||
|
||||
def post_init(self):
|
||||
pass
|
||||
self.quant_cuda_available = False
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx):
|
||||
W = linear.weight.data.clone()
|
||||
if isinstance(linear, nn.Conv2d):
|
||||
W = W.flatten(1)
|
||||
if isinstance(linear, transformers.pytorch_utils.Conv1D):
|
||||
W = W.t()
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
|
@@ -111,10 +88,10 @@ class QuantLinear(nn.Module):
|
|||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
g_idx = idx // self.group_size
|
||||
g_idx = idx // self.groupsize
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(W[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]
|
||||
(linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
|
@@ -196,40 +173,57 @@ class QuantLinear(nn.Module):
|
|||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if self.autogptq_cuda_available is True and (
|
||||
if self.quant_cuda_available is True and (
|
||||
self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
|
||||
):
|
||||
out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
|
||||
if self.use_cuda_fp16:
|
||||
|
||||
if self.faster:
|
||||
x = x.half()
|
||||
if self.bits == 2:
|
||||
self.autogptq_cuda.vecquant2matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
|
||||
quant_cuda.vecquant2matmul_faster_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize, self.half_indim
|
||||
)
|
||||
elif self.bits == 3:
|
||||
self.autogptq_cuda.vecquant3matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
|
||||
quant_cuda.vecquant3matmul_faster_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize, self.half_indim
|
||||
)
|
||||
elif self.bits == 4:
|
||||
self.autogptq_cuda.vecquant4matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
|
||||
|
||||
quant_cuda.vecquant4matmul_faster_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize, self.half_indim
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4 bits are supported.")
|
||||
else:
|
||||
x = x.float()
|
||||
if self.bits == 2:
|
||||
self.autogptq_cuda.vecquant2matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
|
||||
quant_cuda.vecquant2matmul_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize
|
||||
)
|
||||
elif self.bits == 3:
|
||||
self.autogptq_cuda.vecquant3matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
|
||||
quant_cuda.vecquant3matmul_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize
|
||||
)
|
||||
elif self.bits == 4:
|
||||
self.autogptq_cuda.vecquant4matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
|
||||
quant_cuda.vecquant4matmul_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize
|
||||
)
|
||||
elif self.bits == 8:
|
||||
self.autogptq_cuda.vecquant8matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
|
||||
quant_cuda.vecquant8matmul_old(
|
||||
x, self.qweight, out, self.scales.float(), self.qzeros, self.groupsize
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
else:
|
||||
if self.wf.device != self.qzeros.device:
|
||||
self.wf = self.wf.to(self.qzeros.device)
|
||||
|
||||
if self.bits in [2,4,8]:
|
||||
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
|
||||
if self.bits in [2, 4, 8]:
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||
self.wf.unsqueeze(0)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
@@ -237,16 +231,21 @@ class QuantLinear(nn.Module):
|
|||
scales = self.scales
|
||||
scales = scales.reshape(-1, 1, scales.shape[-1])
|
||||
|
||||
weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
|
||||
weight = weight.reshape(-1, self.group_size, weight.shape[2])
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||
self.wf.unsqueeze(-1)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
|
||||
weight = weight.reshape(-1, self.groupsize, weight.shape[2])
|
||||
elif self.bits == 3:
|
||||
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
|
||||
zeros = self.qzeros.reshape(
|
||||
self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
|
||||
).expand(-1, -1, -1, 12)
|
||||
zeros = (zeros >> self.wf.unsqueeze(0))
|
||||
zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4)
|
||||
zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6)
|
||||
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
|
||||
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
|
||||
zeros = zeros & 0x7
|
||||
zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)
|
||||
zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
@@ -254,22 +253,25 @@ class QuantLinear(nn.Module):
|
|||
scales = self.scales
|
||||
scales = scales.reshape(-1, 1, scales.shape[-1])
|
||||
|
||||
weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
|
||||
weight = (weight >> self.wf.unsqueeze(-1))&0x7
|
||||
weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4)
|
||||
weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6)
|
||||
weight = self.qweight.reshape(
|
||||
self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]
|
||||
).expand(-1, -1, 12, -1)
|
||||
weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
|
||||
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
|
||||
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
|
||||
weight = weight & 0x7
|
||||
weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1)
|
||||
weight = weight.reshape(-1, self.group_size, weight.shape[2])
|
||||
weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
|
||||
weight = weight.reshape(-1, self.groupsize, weight.shape[2])
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
weight = (scales * (weight - zeros))
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
out = torch.matmul(x.to(weight.dtype), weight)
|
||||
out = out.half().reshape(out_shape)
|
||||
out = torch.matmul(x.half(), weight)
|
||||
out = out.reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["QuantLinear"]
|
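When the CUDA kernel is not used, the fallback in the old forward() above dequantizes on the fly as weight = scales * (weight - zeros), where the stored zero points carry a -1 offset that the unpack path adds back. A tiny numeric sketch of that affine dequantization with invented values:
import torch

q = torch.tensor([0, 3, 7, 15])     # 4-bit weight codes
scale, zero = 0.05, 8.0             # one quantization group's parameters
w = scale * (q.float() - zero)      # matches weight = scales * (weight - zeros)
print(w)                            # tensor([-0.4000, -0.2500, -0.0500,  0.3500])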
493
auto_gptq/nn_modules/qlinear_triton.py
Normal file
|
@@ -0,0 +1,493 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from logging import getLogger
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .triton_utils import custom_autotune
|
||||
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=2,
|
||||
num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
|
||||
num_stages=3,
|
||||
num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
|
||||
num_stages=2,
|
||||
num_warps=4
|
||||
),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_248_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
scales_ptr, zeros_ptr, g_ptr,
|
||||
M, N, K,
|
||||
bits, maxq,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_ptrs += BLOCK_SIZE_K
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@custom_autotune.autotune(configs=[
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
|
||||
num_stages=2,
|
||||
num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
|
||||
num_stages=3,
|
||||
num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
|
||||
num_stages=2,
|
||||
num_warps=4
|
||||
),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True
|
||||
)
|
||||
@triton.jit
|
||||
def transpose_matmul_248_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
scales_ptr, zeros_ptr, g_ptr,
|
||||
M, N, K,
|
||||
bits, maxq,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, N) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, K) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_k = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_bk
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_n):
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
b = tl.trans(b)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
b_ptrs += BLOCK_SIZE_N
|
||||
scales_ptrs += BLOCK_SIZE_N
|
||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
except ImportError:
|
||||
logger.warning('triton not installed.')
|
||||
raise
|
||||
|
||||
|
||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
|
||||
)
|
||||
matmul_248_kernel[grid](
|
||||
input, qweight, output,
|
||||
scales, qzeros, g_idx,
|
||||
input.shape[0], qweight.shape[1], input.shape[1],
|
||||
bits, maxq,
|
||||
input.stride(0), input.stride(1),
|
||||
qweight.stride(0), qweight.stride(1),
|
||||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output_dim = (qweight.shape[0] * 32) // bits
|
||||
output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
|
||||
transpose_matmul_248_kernel[grid](
|
||||
input, qweight, output,
|
||||
scales, qzeros, g_idx,
|
||||
input.shape[0], qweight.shape[1], output_dim,
|
||||
bits, maxq,
|
||||
input.stride(0), input.stride(1),
|
||||
qweight.stride(0), qweight.stride(1),
|
||||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
||||
ctx.bits, ctx.maxq = bits, maxq
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
||||
bits, maxq = ctx.bits, ctx.maxq
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return grad_input, None, None, None, None, None, None
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
bits,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias
|
||||
):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
if infeatures % 256 != 0 or outfeatures % 256 != 0:
|
||||
raise NotImplementedError("in_feature or out_feature must be divisible by 256.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
self.register_buffer(
|
||||
'qweight',
|
||||
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
W = linear.weight.data.clone()
|
||||
if isinstance(linear, nn.Conv2d):
|
||||
W = W.flatten(1)
|
||||
if isinstance(linear, transformers.pytorch_utils.Conv1D):
|
||||
W = W.t()
|
||||
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(
|
||||
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
|
||||
i = 0
|
||||
row = 0
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = QuantLinearFunction.apply(
|
||||
x.reshape(-1, x.shape[-1]),
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.g_idx,
|
||||
self.bits,
|
||||
self.maxq
|
||||
)
|
||||
out = out.reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out
|
||||
|
||||
|
||||
def autotune_warmup_linear(model, transpose=False, seqlen=2048):
|
||||
"""
|
||||
Pre-tunes the quantized kernel
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
|
||||
kn_values = {}
|
||||
|
||||
for _, m in model.named_modules():
|
||||
if not isinstance(m, QuantLinear):
|
||||
continue
|
||||
|
||||
k = m.infeatures
|
||||
n = m.outfeatures
|
||||
|
||||
if (k, n) not in kn_values:
|
||||
kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
|
||||
|
||||
logger.info(f'Found {len(kn_values)} unique KN Linear values.')
|
||||
logger.info('Warming up autotune cache ...')
|
||||
with torch.no_grad():
|
||||
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
|
||||
m = 2 ** m
|
||||
for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
|
||||
a = torch.randn(m, k, dtype=torch.float16, device='cuda')
|
||||
matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
if transpose:
|
||||
a = torch.randn(m, n, dtype=torch.float16, device='cuda')
|
||||
transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
del kn_values
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantLinear",
|
||||
"autotune_warmup_linear"
|
||||
]
|
|
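A hedged smoke-test sketch for the new qlinear_triton module above; it assumes a CUDA device with triton installed, and the layer sizes are arbitrary multiples of 256 chosen for illustration (freshly constructed buffers are all zeros, so the output is zeros too):
import torch

layer = QuantLinear(bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=False).cuda()
x = torch.randn(8, 4096, dtype=torch.float16, device='cuda')
with torch.no_grad():
    y = layer(x)       # QuantLinearFunction.apply -> matmul248 Triton kernel
print(y.shape)         # torch.Size([8, 4096])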
@@ -1,402 +0,0 @@
|
|||
import torch
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from logging import getLogger
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from . import custom_autotune
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8
|
||||
)
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def quant_matmul_248_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
scales_ptr, zeros_ptr, g_ptr,
|
||||
M, N, K,
|
||||
bits, maxq,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_ptrs += BLOCK_SIZE_K
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 256,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8
|
||||
)
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True
|
||||
)
|
||||
@triton.jit
|
||||
def transpose_quant_matmul_248_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
scales_ptr, zeros_ptr, g_ptr,
|
||||
M, N, K,
|
||||
bits, maxq,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, N) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, K) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_k = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_bk
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_n):
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
b = tl.trans(b)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
b_ptrs += BLOCK_SIZE_N
|
||||
scales_ptrs += BLOCK_SIZE_N
|
||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def silu(x):
|
||||
return x * tl.sigmoid(x)
|
||||
|
||||
|
||||
def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
|
||||
)
|
||||
quant_matmul_248_kernel[grid](
|
||||
input, qweight, output,
|
||||
scales.to(input.dtype), qzeros, g_idx,
|
||||
input.shape[0], qweight.shape[1], input.shape[1],
|
||||
bits, maxq,
|
||||
input.stride(0), input.stride(1),
|
||||
qweight.stride(0), qweight.stride(1),
|
||||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output_dim = (qweight.shape[0] * 32) // bits
|
||||
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
|
||||
transpose_quant_matmul_248_kernel[grid](
|
||||
input, qweight, output,
|
||||
scales.to(input.dtype), qzeros, g_idx,
|
||||
input.shape[0], qweight.shape[1], output_dim,
|
||||
bits, maxq,
|
||||
input.stride(0), input.stride(1),
|
||||
qweight.stride(0), qweight.stride(1),
|
||||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
||||
ctx.bits, ctx.maxq = bits, maxq
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
||||
bits, maxq = ctx.bits, ctx.maxq
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return grad_input, None, None, None, None, None, None
|
||||
|
||||
|
||||
def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
|
||||
)
|
||||
quant_matmul_248_kernel[grid](
|
||||
input, qweight, output,
|
||||
scales, qzeros, g_idx,
|
||||
input.shape[0], qweight.shape[1], input.shape[1],
|
||||
bits, maxq,
|
||||
input.stride(0), input.stride(1),
|
||||
qweight.stride(0), qweight.stride(1),
|
||||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return output
|
|
@@ -1,4 +0,0 @@
|
|||
class TritonModuleMixin:
|
||||
@classmethod
|
||||
def warmup(cls, model, transpose=False, seqlen=2048):
|
||||
pass
|
|
@@ -60,7 +60,7 @@ class GPTQ:
|
|||
self.H += inp.matmul(inp.t())
|
||||
|
||||
def fasterquant(
|
||||
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
|
||||
self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False
|
||||
):
|
||||
W = self.layer.weight.data.clone()
|
||||
if isinstance(self.layer, nn.Conv2d):
|
||||
|
@@ -80,26 +80,10 @@ class GPTQ:
|
|||
H[dead, dead] = 1
|
||||
W[:, dead] = 0
|
||||
|
||||
g_idx = []
|
||||
scale = []
|
||||
zero = []
|
||||
now_idx = 1
|
||||
|
||||
if static_groups:
|
||||
import copy
|
||||
groups = []
|
||||
for i in range(0, self.columns, group_size):
|
||||
quantizer = copy.deepcopy(self.quantizer)
|
||||
quantizer.find_params(W[:, i:(i + group_size)], weight=True)
|
||||
scale.append(quantizer.scale)
|
||||
zero.append(quantizer.zero)
|
||||
groups.append(quantizer)
|
||||
|
||||
if actorder:
|
||||
perm = torch.argsort(torch.diag(H), descending=True)
|
||||
W = W[:, perm]
|
||||
H = H[perm][:, perm]
|
||||
invperm = torch.argsort(perm)
|
||||
|
||||
Losses = torch.zeros_like(W)
|
||||
Q = torch.zeros_like(W)
|
||||
|
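The act-order branch above reorders columns by descending Hessian diagonal before quantizing and restores the original order afterwards via invperm. A minimal round-trip check of that permutation bookkeeping, with a random stand-in Hessian:
import torch

H = torch.rand(6, 6); H = H @ H.t()                   # stand-in for the accumulated Hessian
W = torch.randn(4, 6)
perm = torch.argsort(torch.diag(H), descending=True)  # most "important" columns first
invperm = torch.argsort(perm)
assert torch.equal(W[:, perm][:, invperm], W)         # invperm undoes the reordering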
@@ -112,6 +96,11 @@ class GPTQ:
|
|||
H = torch.linalg.cholesky(H, upper=True)
|
||||
Hinv = H
|
||||
|
||||
g_idx = []
|
||||
scale = []
|
||||
zero = []
|
||||
now_idx = 1
|
||||
|
||||
for i1 in range(0, self.columns, blocksize):
|
||||
i2 = min(i1 + blocksize, self.columns)
|
||||
count = i2 - i1
|
||||
|
@@ -126,20 +115,14 @@ class GPTQ:
|
|||
w = W1[:, i]
|
||||
d = Hinv1[i, i]
|
||||
|
||||
if group_size != -1:
|
||||
if not static_groups:
|
||||
if (i1 + i) % group_size == 0:
|
||||
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
|
||||
if groupsize != -1:
|
||||
if (i1 + i) % groupsize == 0:
|
||||
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
|
||||
|
||||
if ((i1 + i) // group_size) - now_idx == -1:
|
||||
if ((i1 + i) // groupsize) - now_idx == -1:
|
||||
scale.append(self.quantizer.scale)
|
||||
zero.append(self.quantizer.zero)
|
||||
now_idx += 1
|
||||
else:
|
||||
idx = i1 + i
|
||||
if actorder:
|
||||
idx = perm[idx]
|
||||
self.quantizer = groups[idx // group_size]
|
||||
|
||||
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
|
||||
Q1[:, i] = q
|
||||
|
@@ -164,13 +147,11 @@ class GPTQ:
|
|||
logger.info(f'duration: {(time.time() - tick)}')
|
||||
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
|
||||
|
||||
group_size = group_size if group_size != -1 else self.columns
|
||||
if static_groups and actorder:
|
||||
g_idx = [perm[i] // group_size for i in range(self.columns)]
|
||||
else:
|
||||
g_idx = [i // group_size for i in range(self.columns)]
|
||||
groupsize = groupsize if groupsize != -1 else self.columns
|
||||
g_idx = [i // groupsize for i in range(self.columns)]
|
||||
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
|
||||
if actorder:
|
||||
invperm = torch.argsort(perm)
|
||||
Q = Q[:, invperm]
|
||||
g_idx = g_idx[invperm]
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
from .perplexity_utils import Perplexity
|
|
@ -1,48 +0,0 @@
import gc

import torch


def exllama_set_max_input_length(model, max_input_length: int):
    """
    This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.

    When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the
    default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
    """

    # The import is set here to avoid a global import. Arguably this is quite ugly, it would be better to have lazy loading.
    from exllama_kernels import prepare_buffers, cleanup_buffers_cuda

    if not model.quantize_config.desc_act:
        raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")

    device_to_buffers_size = {}
    for device, buffers in model.device_to_buffers.items():
        device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}

    # For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
    for key in list(model.device_to_buffers.keys()):
        del model.device_to_buffers[key]
    model.device_to_buffers = None
    del model.device_to_buffers

    gc.collect()
    torch.cuda.empty_cache()
    cleanup_buffers_cuda()

    device_to_buffers = {}
    for device, buffers_size in device_to_buffers_size.items():
        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        device_to_buffers[device] = {
            "temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
            "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
            "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
            "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
        }

        prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])

    # Buffers need to be persistent to avoid any bug.
    model.device_to_buffers = device_to_buffers

    return model
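A minimal usage sketch for the helper above. It assumes the function is re-exported at the package root (the TORCH_CHECK message in q4_matmul_recons_cuda further down suggests `from auto_gptq import exllama_set_max_input_length`) and that the model was loaded through AutoGPTQForCausalLM with desc_act=True on the exllama backend; the checkpoint path is a placeholder:

from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

# Placeholder path; any act-order GPTQ checkpoint served by the exllama backend works the same way.
model = AutoGPTQForCausalLM.from_quantized("path/to/act-order-gptq-model", device="cuda:0")

# Re-allocate the per-device temp_state / temp_dq buffers for longer prompts
# without reloading the quantized weights.
model = exllama_set_max_input_length(model, max_input_length=4096)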
|
|
@ -1,86 +0,0 @@
from packaging.version import parse as parse_version
from logging import getLogger

import torch

try:
    import triton

    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False

try:
    import autogptq_cuda_256
    import autogptq_cuda_64

    AUTOGPTQ_CUDA_AVAILABLE = True
except:
    AUTOGPTQ_CUDA_AVAILABLE = False


try:
    import exllama_kernels

    EXLLAMA_KERNELS_AVAILABLE = True
except:
    EXLLAMA_KERNELS_AVAILABLE = False

try:
    import exllamav2_kernels

    EXLLAMAV2_KERNELS_AVAILABLE = True
except:
    EXLLAMAV2_KERNELS_AVAILABLE = False

try:
    import cQIGen as qinfer

    QIGEN_AVAILABLE = True
except:
    QIGEN_AVAILABLE = False

logger = getLogger(__name__)


def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2: bool = False, use_qigen: bool = False):
    if use_qigen:
        from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
    else:
        if use_triton:
            if torch.version.hip:
                logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")

            from ..nn_modules.qlinear.qlinear_triton import QuantLinear
        else:
            if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
                from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
            elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
                from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
            elif not desc_act or group_size == -1:
                from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
            else:
                from ..nn_modules.qlinear.qlinear_cuda import QuantLinear

    return QuantLinear


def compare_transformers_version(
    version: str = "v4.28.0",
    op: str = "eq"
):
    assert op in ["eq", "lt", "le", "gt", "ge"]

    from transformers import __version__

    return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))


def compare_pytorch_version(
    version: str = "v2.0.0",
    op: str = "eq"
):
    assert op in ["eq", "lt", "le", "gt", "ge"]

    from torch import __version__

    return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
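A short, hedged sketch of exercising the dispatcher above, assuming the module is importable as auto_gptq.utils.import_utils (implied by the relative imports); the argument values are purely illustrative:

from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    desc_act=True,
    group_size=128,
    bits=4,
    disable_exllama=True,       # skip the exllama v1 kernels
    disable_exllamav2=False,    # prefer the exllama v2 kernels when importable
    use_qigen=False,
)
# With 4-bit weights and exllamav2_kernels installed this resolves to
# qlinear_exllamav2.QuantLinear; otherwise it falls back to the CUDA implementations.
print(QuantLinear.__module__)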
|
|
@ -1,423 +0,0 @@
|
|||
import warnings
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from peft import get_peft_model, PeftConfig, PeftModel, PeftType
|
||||
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
|
||||
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel, Embedding
|
||||
from peft.tuners.adalora import AdaLoraConfig, AdaLoraLayer, AdaLoraModel
|
||||
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
||||
from peft.utils.other import _get_submodules
|
||||
|
||||
from ..modeling._base import BaseGPTQForCausalLM
|
||||
|
||||
|
||||
class GPTQLoraConfig(LoraConfig):
|
||||
injected_fused_attention: bool = False
|
||||
injected_fused_mlp: bool = False
|
||||
|
||||
|
||||
class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
linear_module: torch.nn.Linear,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
**kwargs,
|
||||
):
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
|
||||
torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
|
||||
LoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
|
||||
|
||||
self.linear_module = linear_module
|
||||
|
||||
self.weight.requires_grad = False
|
||||
self.weight = self.linear_module.weight
|
||||
self.bias = self.linear_module.bias
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
def reset_lora_parameters(self, adapter_name):
|
||||
if adapter_name in self.lora_A.keys():
|
||||
torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
|
||||
torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
|
||||
|
||||
def merge(self):
|
||||
raise NotImplementedError("gptq model not support merge lora adapter")
|
||||
|
||||
def unmerge(self):
|
||||
raise NotImplementedError("gptq model not support unmerge lora adapter")
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
previous_dtype = x.dtype
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return self.linear_module(x)
|
||||
if self.disable_adapters:
|
||||
if self.r[self.active_adapter] > 0 and self.merged:
|
||||
self.unmerge()
|
||||
result = self.linear_module(x)
|
||||
elif self.r[self.active_adapter] > 0 and not self.merged:
|
||||
result = self.linear_module(x)
|
||||
|
||||
lora_B = self.lora_B[self.active_adapter]
|
||||
lora_A = self.lora_A[self.active_adapter]
|
||||
lora_dropout = self.lora_dropout[self.active_adapter]
|
||||
scale = self.scaling[self.active_adapter]
|
||||
|
||||
x = x.type_as(lora_A.weight.data)
|
||||
adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
|
||||
result += adapter_result
|
||||
else:
|
||||
result = self.linear_module(x)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class GPTQLoraModel(LoraModel):
|
||||
def _find_and_replace(self, adapter_name):
|
||||
lora_config = self.peft_config[adapter_name]
|
||||
is_target_modules_in_base_model = False
|
||||
kwargs = {
|
||||
"r": lora_config.r,
|
||||
"lora_alpha": lora_config.lora_alpha,
|
||||
"lora_dropout": lora_config.lora_dropout,
|
||||
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
}
|
||||
key_list = [key for key, _ in self.model.named_modules()]
|
||||
for key in key_list:
|
||||
if isinstance(lora_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(lora_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
bias = False
|
||||
if hasattr(target, "bias"):
|
||||
bias = target.bias is not None
|
||||
|
||||
if isinstance(target, LoraLayer):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
lora_config.r,
|
||||
lora_config.lora_alpha,
|
||||
lora_config.lora_dropout,
|
||||
lora_config.init_lora_weights,
|
||||
)
|
||||
else:
|
||||
if isinstance(target, torch.nn.Embedding):
|
||||
embedding_kwargs = kwargs.copy()
|
||||
embedding_kwargs.pop("fan_in_fan_out", None)
|
||||
in_features, out_features = target.num_embeddings, target.embedding_dim
|
||||
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
|
||||
else:
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
f"Currently, only `torch.nn.Linear` and its subclasses are supported."
|
||||
)
|
||||
new_module = GPTQLoraLinear(adapter_name, target, **kwargs)
|
||||
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {lora_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
def _replace_module(self, parent_module, child_name, new_module, old_module):
|
||||
setattr(parent_module, child_name, new_module)
|
||||
if not isinstance(new_module, GPTQLoraLinear):
|
||||
new_module.weight = old_module.weight
|
||||
if hasattr(old_module, "bias"):
|
||||
if old_module.bias is not None:
|
||||
new_module.bias = old_module.bias
|
||||
|
||||
if getattr(old_module, "state", None) is not None:
|
||||
new_module.state = old_module.state
|
||||
new_module.to(old_module.weight.device)
|
||||
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "lora_" in name:
|
||||
module.to(old_module.weight.device)
|
||||
|
||||
def merge_adapter(self):
|
||||
raise NotImplementedError("gptq model not support merge ada lora adapter")
|
||||
|
||||
def unmerge_adapter(self):
|
||||
raise NotImplementedError("gptq model not support unmerge ada lora adapter")
|
||||
|
||||
def merge_and_unload(self):
|
||||
raise NotImplementedError("gptq model not support merge and unload")
|
||||
|
||||
|
||||
class GPTQAdaLoraConfig(AdaLoraConfig):
|
||||
injected_fused_attention: bool = False
|
||||
injected_fused_mlp: bool = False
|
||||
|
||||
|
||||
class GPTQSVDLinear(torch.nn.Linear, AdaLoraLayer):
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
linear_module: torch.nn.Linear,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
**kwargs,
|
||||
):
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
|
||||
torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
|
||||
AdaLoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
|
||||
|
||||
self.linear_module = linear_module
|
||||
|
||||
self.weight.requires_grad = False
|
||||
self.weight = self.linear_module.weight
|
||||
self.bias = self.linear_module.bias
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
def merge(self):
|
||||
raise NotImplementedError("gptq model not support merge lora adapter")
|
||||
|
||||
def unmerge(self):
|
||||
raise NotImplementedError("gptq model not support unmerge lora adapter")
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return self.linear_module(x)
|
||||
if self.disable_adapters:
|
||||
if self.r[self.active_adapter] > 0 and self.merged:
|
||||
self.unmerge()
|
||||
result = self.linear_module(x)
|
||||
elif self.r[self.active_adapter] > 0 and not self.merged:
|
||||
result = self.linear_module(x)
|
||||
result += (
|
||||
(
|
||||
self.lora_dropout[self.active_adapter](x)
|
||||
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
|
||||
@ self.lora_B[self.active_adapter].T
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
/ (self.ranknum[self.active_adapter] + 1e-5)
|
||||
)
|
||||
else:
|
||||
result = self.linear_module(x)
|
||||
return result
|
||||
|
||||
|
||||
class GPTQAdaLoraModel(AdaLoraModel):
|
||||
def _find_and_replace(self, adapter_name):
|
||||
lora_config = self.peft_config[adapter_name]
|
||||
is_target_modules_in_base_model = False
|
||||
kwargs = {
|
||||
"r": lora_config.init_r,
|
||||
"lora_alpha": lora_config.lora_alpha,
|
||||
"lora_dropout": lora_config.lora_dropout,
|
||||
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
}
|
||||
key_list = [key for key, _ in self.model.named_modules()]
|
||||
for key in key_list:
|
||||
if isinstance(lora_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(lora_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
bias = target.bias is not None
|
||||
if isinstance(target, LoraLayer):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
lora_config.init_r,
|
||||
lora_config.lora_alpha,
|
||||
lora_config.lora_dropout,
|
||||
lora_config.init_lora_weights,
|
||||
)
|
||||
else:
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
in_features, out_features = target.in_features, target.out_features
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
f"Currently, only `torch.nn.Linear` and its subclasses are supported."
|
||||
)
|
||||
new_module = GPTQSVDLinear(adapter_name, target, **kwargs)
|
||||
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {lora_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
def _replace_module(self, parent_module, child_name, new_module, old_module):
|
||||
setattr(parent_module, child_name, new_module)
|
||||
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "lora_" in name:
|
||||
module.to(old_module.weight.device)
|
||||
|
||||
def merge_adapter(self):
|
||||
raise NotImplementedError("gptq model not support merge ada lora adapter")
|
||||
|
||||
def unmerge_adapter(self):
|
||||
raise NotImplementedError("gptq model not support unmerge ada lora adapter")
|
||||
|
||||
def merge_and_unload(self):
|
||||
raise NotImplementedError("gptq model not support merge and unload")
|
||||
|
||||
|
||||
def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True):
|
||||
if not ignore:
|
||||
ignore = []
|
||||
lm_head_name = model.lm_head_name
|
||||
if ignore_lm_head and lm_head_name not in ignore:
|
||||
ignore.append(lm_head_name)
|
||||
results = set()
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
res = n.split('.')[-1]
|
||||
if res not in ignore:
|
||||
results.add(res)
|
||||
return list(results)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def hijack_peft_mappings():
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
|
||||
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
|
||||
raise
|
||||
finally:
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
|
||||
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
|
||||
|
||||
|
||||
def get_gptq_peft_model(
|
||||
model: BaseGPTQForCausalLM,
|
||||
peft_config: PeftConfig = None,
|
||||
model_id: str = None,
|
||||
adapter_name: str = "default",
|
||||
auto_find_all_linears: bool = True,
|
||||
train_mode: bool = False
|
||||
):
|
||||
if train_mode and not model.trainable:
|
||||
model.enable_trainable_mode()
|
||||
if train_mode and not peft_config:
|
||||
raise ValueError("peft_config not specified when in train mode.")
|
||||
if not train_mode and not model_id:
|
||||
raise ValueError("model_id(where to load adapters) not specified when in inference mode.")
|
||||
|
||||
if model.fused_attn_module_type is not None and not model.injected_fused_attention:
|
||||
peft_types = [PeftType.LORA.value, PeftType.ADALORA.value]
|
||||
warnings.warn(
|
||||
f"You can just ignore this warning if the peft type you use isn't in {peft_types}.\n"
|
||||
f"{model.__class__.__name__} supports injecting fused attention but not enables this time. "
|
||||
"If you are training adapters, you must also disable fused attention injection when loading quantized "
|
||||
"base model at inference time, otherwise adapters may not be added to base model properly. "
|
||||
"If you are loading adapters to do inference, you can reference to adapter's config file to check "
|
||||
"whether the adapters are trained using base model that not enable fused attention injection."
|
||||
)
|
||||
if model.injected_fused_mlp:
|
||||
raise NotImplementedError("GPTQ model that enables fused mlp injection is not supported to integrate with peft.")
|
||||
|
||||
if train_mode:
|
||||
peft_type = peft_config.peft_type
|
||||
if not isinstance(peft_type, str):
|
||||
peft_type = peft_type.value
|
||||
if peft_type in [PeftType.LORA.value, PeftType.ADALORA.value]:
|
||||
if auto_find_all_linears:
|
||||
peft_config.target_modules = find_all_linear_names(model, ignore_lm_head=True)
|
||||
if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
|
||||
peft_config = GPTQLoraConfig(**peft_config.to_dict())
|
||||
if peft_type == PeftType.ADALORA.value and not isinstance(peft_config, GPTQAdaLoraConfig):
|
||||
peft_config = GPTQAdaLoraConfig(**peft_config.to_dict())
|
||||
peft_config.injected_fused_attention = model.injected_fused_attention
|
||||
peft_config.injected_fused_mlp = model.injected_fused_mlp
|
||||
if peft_type == PeftType.ADAPTION_PROMPT.value:
|
||||
if peft_config.adapter_layers > model.config.num_hidden_layers:
|
||||
warnings.warn(
|
||||
f"model has only {model.config.num_hidden_layers} layers "
|
||||
f"but adapter_layers is set to {peft_config.adapter_layers}, "
|
||||
f"will reset value to {model.config.num_hidden_layers}."
|
||||
)
|
||||
peft_config.adapter_layers = model.config.num_hidden_layers
|
||||
if model.injected_fused_attention:
|
||||
raise NotImplementedError(
|
||||
"model with fused attention injected isn't supported to use ADAPTION_PROMPT peft type yet."
|
||||
)
|
||||
|
||||
with hijack_peft_mappings():
|
||||
try:
|
||||
if train_mode:
|
||||
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
|
||||
else:
|
||||
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
|
||||
except:
|
||||
raise NotImplementedError(
|
||||
f"{model.__class__.__name__} not support {peft_config.peft_type.value} peft type yet."
|
||||
)
|
||||
|
||||
return peft_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GPTQLoraConfig",
|
||||
"GPTQLoraModel",
|
||||
"GPTQAdaLoraConfig",
|
||||
"GPTQAdaLoraModel",
|
||||
"find_all_linear_names",
|
||||
"get_gptq_peft_model"
|
||||
]
|
|
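A hedged sketch of wiring a quantized model into the helpers above for LoRA training. The import path, the LoRA hyperparameters, and the task_type value are illustrative assumptions; `model` stands for a BaseGPTQForCausalLM loaded without fused attention injection, as the warning in get_gptq_peft_model requires for training:

from auto_gptq.utils.peft_utils import GPTQLoraConfig, get_gptq_peft_model

peft_config = GPTQLoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",  # assumed; any causal-LM task type accepted by peft
)
peft_model = get_gptq_peft_model(
    model,                       # quantized BaseGPTQForCausalLM in trainable mode
    peft_config=peft_config,
    train_mode=True,
    auto_find_all_linears=True,  # target_modules filled in via find_all_linear_names
)
peft_model.print_trainable_parameters()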
@ -1,215 +0,0 @@
|
|||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
class Perplexity:
|
||||
"""
|
||||
A class for calculating the perplexity of a language model.
|
||||
"""
|
||||
|
||||
def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
|
||||
"""
|
||||
Calculate perplexity using the same method as seen in llama.cpp.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : AutoModelForCausalLM
|
||||
The language model for which the perplexity is calculated.
|
||||
tokenizer : AutoTokenizer
|
||||
The tokenizer corresponding to the model.
|
||||
device : str, optional
|
||||
The device to run the calculations on. If auto, the device that your model uses
|
||||
will be the device used for these calculations. Default is 'auto'.
|
||||
dataset_path : str, optional
|
||||
The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
|
||||
dataset_name : str, optional
|
||||
The name of the dataset. Default is None.
|
||||
split : str, optional
|
||||
The split of the dataset to use. Default is 'test'.
|
||||
text_column : str, optional
|
||||
The name of the column in the dataset that contains the text data. Default is 'text'.
|
||||
"""
|
||||
self._model = model
|
||||
self._tokenizer = tokenizer
|
||||
self._dataset_path = dataset_path
|
||||
self._dataset_name = dataset_name
|
||||
self._split = split
|
||||
self._text_column = text_column
|
||||
self._text = self._prepare_data()
|
||||
|
||||
def _get_device(self):
|
||||
if torch.backends.mps.is_available():
|
||||
return 'mps'
|
||||
elif torch.cuda.is_available():
|
||||
return 'cuda:0'
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
def _prepare_data(self):
|
||||
"""
|
||||
Prepares the dataset by loading and formatting.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The formatted dataset as a single string.
|
||||
"""
|
||||
if self._dataset_path == 'wikitext':
|
||||
self._dataset_name = 'wikitext-2-raw-v1'
|
||||
|
||||
# Load the dataset
|
||||
data = load_dataset(self._dataset_path, self._dataset_name, split=self._split)
|
||||
# Format the text column of the dataset
|
||||
text_list = [' \n' if s == '' else s for s in data[self._text_column]]
|
||||
return ''.join(text_list)
|
||||
|
||||
@staticmethod
|
||||
def softmax(logits):
|
||||
"""
|
||||
Static method for applying the softmax function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : np.ndarray
|
||||
The input to the softmax function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The output of the softmax function.
|
||||
"""
|
||||
e_x = np.exp(logits - np.max(logits))
|
||||
return e_x / e_x.sum(axis=0)
|
||||
|
||||
def calculate_perplexity(self, n_ctx=512, n_batch=512):
|
||||
"""
|
||||
Calculates the perplexity of the language model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_ctx : int
|
||||
The context size.
|
||||
n_batch : int
|
||||
The batch size.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
The list of perplexity scores calculated.
|
||||
"""
|
||||
# Tokenize the text
|
||||
self._tokenizer.model_max_length = sys.maxsize
|
||||
tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)
|
||||
|
||||
nll = 0.0 # Negative log likelihood
|
||||
count = 0 # Counter for processed tokens
|
||||
curr_ppl = 0
|
||||
all_perplexity = []
|
||||
|
||||
with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress:
|
||||
for i in progress:
|
||||
# Process each batch of tokens
|
||||
nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)
|
||||
|
||||
# Calculate and display the current perplexity
|
||||
curr_ppl = np.exp(nll / count)
|
||||
all_perplexity.append(curr_ppl)
|
||||
progress.set_description(f"Perplexity: {curr_ppl:.4f}")
|
||||
|
||||
return all_perplexity
|
||||
|
||||
def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
|
||||
"""
|
||||
Processes each batch of tokens.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
i : int
|
||||
The batch index.
|
||||
n_ctx : int
|
||||
The context size.
|
||||
n_batch : int
|
||||
The batch size.
|
||||
tokens : torch.Tensor
|
||||
The tokenized text.
|
||||
nll : float
|
||||
The current negative log likelihood.
|
||||
count : int
|
||||
The current count of processed tokens.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The updated negative log likelihood.
|
||||
int
|
||||
The updated count of processed tokens.
|
||||
"""
|
||||
start = i * n_ctx
|
||||
end = start + n_ctx
|
||||
|
||||
num_batches = (n_ctx + n_batch - 1) // n_batch
|
||||
|
||||
logits = []
|
||||
|
||||
for j in range(num_batches):
|
||||
batch_start = start + j * n_batch
|
||||
batch_size = min(end - batch_start, n_batch)
|
||||
|
||||
token_org = tokens[0][batch_start].item()
|
||||
|
||||
if j == 0:
|
||||
# Replace the first token with the BOS token
|
||||
tokens[0][batch_start] = self._tokenizer.bos_token_id
|
||||
|
||||
# Compute the logits for the current batch of tokens
|
||||
batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)
|
||||
|
||||
tokens[0][batch_start] = token_org
|
||||
|
||||
logits.append(batch_logits)
|
||||
|
||||
# We rely on the fact that attention in the forward pass only looks at previous
|
||||
# tokens here, so the logits returned for each token are an accurate representation
|
||||
# of what the model would have predicted at that point.
|
||||
#
|
||||
# Example, we have a context window of 512, we will compute perplexity for each of the
|
||||
# last 256 tokens. Then, we split the input up into context window size chunks to
|
||||
# process the entire prompt.
|
||||
|
||||
for j in range(min(512, n_ctx // 2), n_ctx - 1):
|
||||
tok_logits = logits[0][0][j].cpu().numpy()
|
||||
# Compute the probability of the next token
|
||||
prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]
|
||||
|
||||
# Update the negative log likelihood and the count of processed tokens
|
||||
nll += -np.log(prob, where=prob>0)
|
||||
count += 1
|
||||
|
||||
return nll, count
|
||||
|
||||
def _compute_batch_logits(self, tokens, batch_start, batch_size):
|
||||
"""
|
||||
Computes the logits for a batch of tokens.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : torch.Tensor
|
||||
The tokenized text.
|
||||
batch_start : int
|
||||
The start index of the batch.
|
||||
batch_size : int
|
||||
The size of the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The logits for the batch of tokens.
|
||||
"""
|
||||
# Compute the logits without keeping track of gradients
|
||||
with torch.no_grad():
|
||||
outputs = self._model(tokens[:, batch_start:batch_start+batch_size])
|
||||
return outputs.logits.detach()
|
|
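The score appended on each step of calculate_perplexity is exp(nll / count), i.e. the exponential of the mean negative log-likelihood of the tokens scored so far. A minimal sketch of driving the Perplexity class defined above; the model id is a placeholder and the loading calls are ordinary transformers usage, not part of this file:

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "path/to/causal-lm"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

ppl = Perplexity(model, tokenizer, dataset_path="wikitext", split="test", text_column="text")
scores = ppl.calculate_perplexity(n_ctx=512, n_batch=512)
print(scores[-1])  # running perplexity over everything scored so far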
@ -1,187 +0,0 @@
|
|||
#include <torch/all.h>
|
||||
#include <torch/python.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
void vecquant2matmul_cuda(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
);
|
||||
|
||||
void vecquant2matmul(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
|
||||
}
|
||||
|
||||
void vecquant3matmul_cuda(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
);
|
||||
|
||||
void vecquant3matmul(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
|
||||
}
|
||||
|
||||
void vecquant4matmul_cuda(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
);
|
||||
|
||||
void vecquant4matmul(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
|
||||
}
|
||||
|
||||
void vecquant8matmul_cuda(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
);
|
||||
|
||||
void vecquant8matmul(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
torch::Tensor g_idx
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
|
||||
}
|
||||
|
||||
|
||||
// old
|
||||
|
||||
void vecquant2matmul_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
);
|
||||
|
||||
void vecquant2matmul_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant2matmul_cuda_old(vec, mat, mul, scales, zeros,groupsize);
|
||||
}
|
||||
|
||||
void vecquant3matmul_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
);
|
||||
|
||||
void vecquant3matmul_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant3matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
|
||||
}
|
||||
|
||||
void vecquant4matmul_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
);
|
||||
|
||||
void vecquant4matmul_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant4matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
|
||||
}
|
||||
|
||||
void vecquant8matmul_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
);
|
||||
|
||||
void vecquant8matmul_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant8matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
|
||||
}
|
||||
|
||||
void vecquant2matmul_faster_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
);
|
||||
|
||||
void vecquant2matmul_faster_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant2matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
|
||||
}
|
||||
|
||||
void vecquant3matmul_faster_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
);
|
||||
|
||||
void vecquant3matmul_faster_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant3matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
|
||||
}
|
||||
|
||||
void vecquant4matmul_faster_cuda_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
);
|
||||
|
||||
void vecquant4matmul_faster_old(
|
||||
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor scales, torch::Tensor zeros,
|
||||
int groupsize, int vec_height
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
vecquant4matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
|
||||
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
|
||||
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
|
||||
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
|
||||
|
||||
m.def("vecquant2matmul_old", &vecquant2matmul_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
|
||||
m.def("vecquant3matmul_old", &vecquant3matmul_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
|
||||
m.def("vecquant4matmul_old", &vecquant4matmul_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
|
||||
m.def("vecquant8matmul_old", &vecquant8matmul_old, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
|
||||
m.def("vecquant2matmul_faster_old", &vecquant2matmul_faster_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA), faster version");
|
||||
m.def("vecquant3matmul_faster_old", &vecquant3matmul_faster_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
|
||||
m.def("vecquant4matmul_faster_old", &vecquant4matmul_faster_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version");
|
||||
}
|
File diff suppressed because it is too large
|
@ -1,58 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_compat_cuh
|
||||
#define _cuda_compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -1,75 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#define _cuda_buffers_cu
|
||||
#include "cuda_buffers.cuh"
|
||||
|
||||
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
|
||||
// __constant__ half2 q4_table[16][256];
|
||||
// half2 q4_table_host[16][256];
|
||||
// bool q4_table_init = false;
|
||||
|
||||
CudaBuffers::CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
temp_state_size(_temp_state_size),
|
||||
temp_state(_temp_state),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(_device);
|
||||
|
||||
cudaStreamCreate(&alt_stream_1);
|
||||
cudaStreamCreate(&alt_stream_2);
|
||||
cudaStreamCreate(&alt_stream_3);
|
||||
cudaEventCreate(&alt_stream_1_done);
|
||||
cudaEventCreate(&alt_stream_2_done);
|
||||
cudaEventCreate(&alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers::~CudaBuffers()
|
||||
{
|
||||
cudaStreamDestroy(alt_stream_1);
|
||||
cudaStreamDestroy(alt_stream_2);
|
||||
cudaStreamDestroy(alt_stream_3);
|
||||
cudaEventDestroy(alt_stream_1_done);
|
||||
cudaEventDestroy(alt_stream_2_done);
|
||||
cudaEventDestroy(alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index)
|
||||
{
|
||||
return g_buffers[device_index];
|
||||
}
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
)
|
||||
{
|
||||
CudaBuffers* buffers = new CudaBuffers
|
||||
(
|
||||
_device,
|
||||
_temp_state_size,
|
||||
_temp_state,
|
||||
_temp_dq
|
||||
);
|
||||
|
||||
g_buffers[_device] = buffers;
|
||||
}
|
||||
|
||||
void cleanup_buffers_cuda()
|
||||
{
|
||||
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
|
||||
{
|
||||
if (!g_buffers[i]) continue;
|
||||
delete g_buffers[i];
|
||||
g_buffers[i] = NULL;
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_buffers_cuh
|
||||
#define _cuda_buffers_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
const int CUDA_MAX_DEVICES = 16;
|
||||
|
||||
// #ifndef _cuda_buffers_cu
|
||||
// extern __constant__ half2 q4_table[16][256];
|
||||
// #endif
|
||||
|
||||
class CudaBuffers
|
||||
{
|
||||
public:
|
||||
int device;
|
||||
|
||||
half* temp_state; // [max_hidden_rows * intermediate_size]
|
||||
int temp_state_size;
|
||||
half* temp_dq; // size of largest quant tensor * 8
|
||||
|
||||
cudaStream_t alt_stream_1;
|
||||
cudaStream_t alt_stream_2;
|
||||
cudaStream_t alt_stream_3;
|
||||
cudaEvent_t alt_stream_1_done;
|
||||
cudaEvent_t alt_stream_2_done;
|
||||
cudaEvent_t alt_stream_3_done;
|
||||
|
||||
CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
~CudaBuffers();
|
||||
};
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index);
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
void cleanup_buffers_cuda();
|
||||
|
||||
#endif
|
|
@ -1,63 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "column_remap.cuh"
|
||||
#include "../util.cuh"
|
||||
|
||||
const int SHUF_BLOCKSIZE_X = 256;
|
||||
const int SHUF_BLOCKSIZE_Y = 16;
|
||||
|
||||
__global__ void column_remap_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
half* __restrict__ x_new,
|
||||
const int x_width,
|
||||
const int x_height,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
|
||||
if (x_column >= x_width) return;
|
||||
//if (x_row >= x_height) return;
|
||||
|
||||
int x_stride = x_width;
|
||||
int x_idx = x_row * x_stride + x_column;
|
||||
|
||||
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
|
||||
int x_idx_end = x_row_end * x_stride + x_column;
|
||||
|
||||
int s_column = x_map[x_column];
|
||||
int s_idx = x_row * x_stride + s_column;
|
||||
|
||||
while (x_idx < x_idx_end)
|
||||
{
|
||||
x_new[x_idx] = x[s_idx];
|
||||
x_idx += x_stride;
|
||||
s_idx += x_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Remap columns in x to correspond to sequential group index before matmul
|
||||
//
|
||||
// perform x -> seq_x such that seq_x @ seq_w == x @ w
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
|
||||
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
|
||||
1
|
||||
);
|
||||
|
||||
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
|
||||
}
|
|
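A small NumPy sketch (illustration only, not part of the CUDA sources) of the invariant stated in the comment above: remapping the columns of x through x_map while the rows of w are reordered by the same permutation leaves the product unchanged, so seq_x @ seq_w == x @ w:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
w = rng.standard_normal((8, 5))

x_map = rng.permutation(8)   # column j of seq_x is taken from column x_map[j] of x
seq_x = x[:, x_map]
seq_w = w[x_map, :]          # rows of w reordered by the same map

assert np.allclose(seq_x @ seq_w, x @ w)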
@ -1,19 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif
|
|
@ -1,260 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
#include "../cu_compat.cuh"
|
||||
#include "../cuda_buffers.cuh"
|
||||
#if defined(USE_ROCM)
|
||||
#include "../hip_compat.cuh"
|
||||
#endif
|
||||
|
||||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
half*,
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint32_t*,
|
||||
bool
|
||||
);
|
||||
|
||||
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
__global__ void q4_matmul_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int dim,
|
||||
const int width,
|
||||
const int groupsize,
|
||||
const int block_size_z,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_half x_(x, height, dim);
|
||||
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
||||
MatrixView_q4_column w_(w, dim, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
|
||||
// Zero output
|
||||
|
||||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
half2 acc = {};
|
||||
half acc_h = {};
|
||||
|
||||
if constexpr (use_groupsize)
|
||||
{
|
||||
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
||||
// could be slightly faster
|
||||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to block result
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(__low2half(acc), __high2half(acc));
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
{
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
||||
}
|
||||
}
|
||||
|
||||
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
||||
{
|
||||
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
if (tuningParams->matmul_no_half2) {
|
||||
if (block_size_z % groupsize == 0) {
|
||||
if (x_map) return q4_matmul_kernel<false, true, true >;
|
||||
else return q4_matmul_kernel<false, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<false, false, true >;
|
||||
else return q4_matmul_kernel<false, false, false>;
|
||||
}
|
||||
} else {
|
||||
if (block_size_z % groupsize == 0)
|
||||
{
|
||||
if (x_map) return q4_matmul_kernel<true, true, true >;
|
||||
else return q4_matmul_kernel<true, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<true, false, true >;
|
||||
else return q4_matmul_kernel<true, false, false>;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute y = x @ w
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
|
||||
uint32_t* x_map = w->cuda_x_map;
|
||||
const half* x_mapped = x;
|
||||
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
|
||||
{
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
x_map = NULL;
|
||||
}
|
||||
|
||||
int block_size_z;
|
||||
if (w->width == 4096) block_size_z = 384; // 7B
|
||||
else if (w->width == 11008) block_size_z = 256;
|
||||
else if (w->width == 5120) block_size_z = 384; // 13B
|
||||
else if (w->width == 13824) block_size_z = 256;
|
||||
else if (w->width == 6656) block_size_z = 256; // 33B
|
||||
else if (w->width == 17920) block_size_z = 128;
|
||||
else block_size_z = 256;
|
||||
|
||||
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
|
||||
|
||||
dim3 threads(THREADS_X, THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height + threads.y - 1) / threads.y,
|
||||
(dim + block_size_z - 1) / block_size_z
|
||||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
|
||||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
||||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
|
||||
const float alpha = 1.0f;
|
||||
const float beta = no_zero ? 1.0f : 0.0f;
|
||||
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
#else
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
#endif
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matmul_cuh
|
||||
#define _q4_matmul_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include "../tuning.h"
|
||||
|
||||
// Workaround for hipify_python using rocblas instead of hipblas.
|
||||
#if defined(USE_ROCM)
|
||||
#include <hipblas/hipblas.h>
|
||||
#define rocblas_handle hipblasHandle_t
|
||||
#endif
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero = false,
|
||||
cudaStream_t alt_stream = NULL
|
||||
);
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero = false
|
||||
);
|
||||
|
||||
#endif
|
|
@ -1,225 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include <vector>
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
const int UNSHUF_BLOCKSIZE_X = 64;
|
||||
|
||||
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
|
||||
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
|
||||
|
||||
vector<Q4Matrix*> g_q4_matrices;
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m)
|
||||
{
|
||||
g_q4_matrices.push_back(m);
|
||||
}
|
||||
|
||||
void g_q4_free_matrices()
|
||||
{
|
||||
for (const auto& m : g_q4_matrices) delete m;
|
||||
g_q4_matrices.clear();
|
||||
}
|
||||
|
||||
Q4Matrix::Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
) :
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
device(_device)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_qweight = _qweight;
|
||||
cuda_qzeros = _qzeros;
|
||||
cuda_scales = _scales;
|
||||
|
||||
groupsize = height / groups;
|
||||
|
||||
if (_g_idx) make_sequential(_g_idx);
|
||||
}
|
||||
|
||||
Q4Matrix::~Q4Matrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Make sequential
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int x_map_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = x_map[x_map_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||
dim3 blocks
|
||||
(
|
||||
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
|
||||
height / 8,
|
||||
1
|
||||
);
|
||||
|
||||
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out, // (y)
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int width,
|
||||
const int groupsize
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
|
||||
if (column >= width) return;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, height / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
|
||||
|
||||
// Groupsize version
|
||||
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
void Q4Matrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height / 8 + threads.y - 1) / threads.y,
|
||||
1
|
||||
);
|
||||
|
||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
|
@@ -1,53 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

class Q4Matrix
{
public:

    int device;

    int height;
    int width;
    int groups;
    int groupsize;

    uint32_t* cuda_qweight = NULL;
    uint32_t* cuda_qzeros = NULL;
    half* cuda_scales = NULL;
    uint32_t* cuda_x_map = NULL;

    Q4Matrix
    (
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _qweight,
        uint32_t* _qzeros,
        half* _scales,
        uint32_t* _g_idx,

        const int _device
    );

    ~Q4Matrix();

    void reconstruct(half* out);

private:

    void make_sequential(const uint32_t* cpu_g_idx);

};

void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif
@ -1,255 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "cuda_func/q4_matrix.cuh"
|
||||
#include "cuda_func/q4_matmul.cuh"
|
||||
#include "cuda_func/column_remap.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
// buffer size used for sanity checks
|
||||
temp_state.numel(),
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor
|
||||
|
||||
void column_remap
|
||||
(
|
||||
torch::Tensor x,
|
||||
torch::Tensor x_new,
|
||||
torch::Tensor x_map
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_new, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_map, kInt);
|
||||
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
|
||||
|
||||
int height = x.size(0);
|
||||
int width = x.size(1);
|
||||
|
||||
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
column_remap_cuda
|
||||
(
|
||||
(half*) x.data_ptr(),
|
||||
(half*) x_new.data_ptr(),
|
||||
height,
|
||||
width,
|
||||
(uint32_t*) x_map.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
|
||||
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
|
||||
m.def("cleanup", &cleanup, "cleanup");
|
||||
m.def("make_q4", &make_q4, "make_q4");
|
||||
m.def("q4_matmul", &q4_matmul, "q4_matmul");
|
||||
m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
|
||||
}
|
|
@@ -1,49 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}

#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif
@ -1,294 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _matrix_cuh
|
||||
#define _matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
|
||||
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
|
||||
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
|
||||
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
|
||||
|
||||
half2 tmp = __hmul2(*h_ptr++, v_01);
|
||||
tmp = __hfma2(*h_ptr++, v_23, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_45, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(*h_ptr++, v_0);
|
||||
tmp = __hfma(*h_ptr++, v_1, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_2, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_3, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_4, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_5, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_6, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
half h_0 = h_ptr[*x_map_ptr++];
|
||||
half h_1 = h_ptr[*x_map_ptr++];
|
||||
half h_2 = h_ptr[*x_map_ptr++];
|
||||
half h_3 = h_ptr[*x_map_ptr++];
|
||||
half h_4 = h_ptr[*x_map_ptr++];
|
||||
half h_5 = h_ptr[*x_map_ptr++];
|
||||
half h_6 = h_ptr[*x_map_ptr++];
|
||||
half h_7 = h_ptr[*x_map_ptr++];
|
||||
|
||||
half2 h_01 = __halves2half2(h_0, h_1);
|
||||
half2 h_23 = __halves2half2(h_2, h_3);
|
||||
half2 h_45 = __halves2half2(h_4, h_5);
|
||||
half2 h_67 = __halves2half2(h_6, h_7);
|
||||
|
||||
half2 tmp = __hmul2(h_01, v_01);
|
||||
tmp = __hfma2(h_23, v_23, tmp);
|
||||
tmp = __hfma2(h_45, v_45, tmp);
|
||||
tmp = __hfma2(h_67, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_x_map_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
|
@@ -1,13 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _tuning_h
#define _tuning_h

struct ExLlamaTuning
{
    int matmul_recons_thd;
    bool matmul_fused_remap;
    bool matmul_no_half2;
};

#endif
@@ -1,33 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif

// React to failure on return code != cudaSuccess

#define _cuda_check(fn) \
do { \
    {_cuda_err = fn;} \
    if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)

// React to failure on return code == 0

#define _alloc_check(fn) \
do { \
    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
    else _cuda_err = cudaSuccess; \
} while(false)

#endif
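The _cuda_check and _alloc_check macros above rely on an implicit contract: the enclosing function must declare a local cudaError_t named _cuda_err and provide a _cuda_fail label to jump to. A minimal sketch of that contract follows; the function itself is hypothetical and not part of the deleted sources.

// Hypothetical example of the calling convention _cuda_check expects.
cudaError_t alloc_and_clear(void** ptr, size_t bytes)
{
    cudaError_t _cuda_err = cudaSuccess;
    _cuda_check(cudaMalloc(ptr, bytes));        // on failure, jumps to _cuda_fail
    _cuda_check(cudaMemset(*ptr, 0, bytes));
    return cudaSuccess;

_cuda_fail:
    return _cuda_err;
}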
@@ -1,13 +0,0 @@
#ifndef _config_h
#define _config_h

#define MAX_Q_GEMM_ROWS 50

#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0

#endif
@@ -1,12 +0,0 @@
#ifndef _util_h
#define _util_h

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)

#endif
@@ -1,56 +0,0 @@
#ifndef _compat_cuh
#define _compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
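These overloads exist so the GEMM kernels further down in this diff can accumulate per-block partial sums with atomicAdd on half/half2 even on devices without native FP16 atomics. A minimal illustrative kernel follows; it is not part of the deleted sources and the kernel name is hypothetical.

// Hypothetical kernel: each thread adds one half2 partial sum into the output.
// On CC < 6.0 or ROCm this resolves to the CAS-based atomicAdd_half2 above;
// newer CUDA devices use the native half2 atomicAdd instead.
__global__ void accumulate_partials(half2* out, const half2* partials, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(&out[i], partials[i]);
}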
@@ -1,121 +0,0 @@
#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "quant/qdq_util.cuh"

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*) item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __low2half(i01);
        items[1] = __high2half(i01);
        items[2] = __low2half(i23);
        items[3] = __high2half(i23);
    }
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2float(__low2half(i01));
        items[1] = __half2float(__high2half(i01));
        items[2] = __half2float(__low2half(i23));
        items[3] = __half2float(__high2half(i23));
    }

    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2half2(__low2half(i01));
        items[1] = __half2half2(__high2half(i01));
        items[2] = __half2half2(__low2half(i23));
        items[3] = __half2half2(__high2half(i23));
    }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }

    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2 v01 = __halves2half2(v0, v1);
        half2 v23 = __halves2half2(v2, v3);
        half2* ptr = (half2*) item_ptr(row, column);
        ptr[0] = v01;
        ptr[1] = v23;
    }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }

    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
    }

    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
        items[2] = (d >> 8) & 0x0f;
        items[3] = (d >> 12) & 0x0f;
    }
};

#endif
@ -1,238 +0,0 @@
|
|||
#include "q_gemm.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "../config.h"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "q_gemm_kernel.cuh"
|
||||
#include "q_gemm_kernel_gptq.cuh"
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
#define hipblasHgemm __compat_hipblasHgemm
|
||||
|
||||
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
#endif
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
bool clear
|
||||
)
|
||||
{
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_q_scale,
|
||||
b->cuda_q_scale_max,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_8,
|
||||
b->rows_6,
|
||||
b->rows_5,
|
||||
b->rows_4,
|
||||
b->rows_3,
|
||||
b->rows_2,
|
||||
clear
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
|
||||
// DBGX((uint64_t) b->cuda_q_perm);
|
||||
// DBGI(b->rows_4);
|
||||
// DBGI(b->height);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_gptq_qzeros,
|
||||
b->cuda_gptq_scales,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_4,
|
||||
clear
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear,
|
||||
half* temp_dq,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||
{
|
||||
//printf("cublas\n");
|
||||
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
|
||||
if (!temp_dq) temp_dq = b->temp_dq;
|
||||
b->reconstruct(temp_dq);
|
||||
|
||||
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasSgemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasGemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n,
|
||||
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("cuda\n");
|
||||
|
||||
// Quantized matmul
|
||||
|
||||
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
||||
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void clear_kernel
|
||||
(
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n
|
||||
)
|
||||
{
|
||||
int m = blockIdx.y;
|
||||
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
|
||||
if (n >= size_n) return;
|
||||
int4* c_ptr = (int4*)(c + m * size_n + n);
|
||||
*c_ptr = {};
|
||||
}
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
)
|
||||
{
|
||||
return;
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = CLEAR_N_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
gridDim.y = size_m;
|
||||
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
}
|
|
@@ -1,33 +0,0 @@
#ifndef _q_gemm_cuh
#define _q_gemm_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q_matrix.cuh"

void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    QMatrix* b,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    bool clear = false,
    half* reconstruct = NULL,
    bool force_cuda = false
);

void clear_tensor_cuda
(
    half* c,
    int size_m,
    int size_n
);

#endif
@ -1,484 +0,0 @@
|
|||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const bool
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
// Preload scales
|
||||
|
||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
||||
|
||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
for (int g = 0; g < groups_in_block; g++)
|
||||
{
|
||||
int qscales[4];
|
||||
b_q_scale_.item4(qscales, group + g, n);
|
||||
qscales[0]++;
|
||||
qscales[1]++;
|
||||
qscales[2]++;
|
||||
qscales[3]++;
|
||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
||||
}
|
||||
|
||||
// a, b offset
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int scales_idx = 0;
|
||||
float qs_f0 = scales[scales_idx][0];
|
||||
float qs_f1 = scales[scales_idx][1];
|
||||
float qs_f2 = scales[scales_idx][2];
|
||||
float qs_f3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize groups
|
||||
|
||||
int k = offset_k;
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[2];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
||||
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
||||
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
||||
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[5];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
||||
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
||||
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
||||
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
||||
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
||||
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
||||
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
||||
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
||||
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
||||
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
// Accumulate column sums in c
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
    #if BLOCK_M_SIZE_MAX >= 1
    if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 2
    if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 3
    if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 4
    if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 5
    if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 6
    if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 7
    if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 8
    if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
    #endif
    return NULL;
}
|
|
@ -1,219 +0,0 @@
|
|||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hadd2(result, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
}
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const bool
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_4,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
|
||||
half a0;
|
||||
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
else a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero output
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// a, b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[4];
|
||||
float scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
// __syncthreads();
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
|
||||
int k = offset_k;
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
a_ptr += 8;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
    #if BLOCK_M_SIZE_MAX >= 1
    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 2
    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 3
    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 4
    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 5
    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 6
    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 7
    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 8
    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
    #endif
    return NULL;
}
|
|
@ -1,603 +0,0 @@
|
|||
#include "q_matrix.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "util.cuh"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
|
||||
#define THREADS_X 32
|
||||
#define THREADS_Y 32
|
||||
|
||||
// Shuffle quantized data on load
|
||||
|
||||
__global__ void shuffle_kernel
|
||||
(
|
||||
uint32_t* __restrict__ b_q_weight,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
||||
if (n >= size_n) return;
|
||||
int k = 0;
|
||||
uint32_t* b_ptr = b_q_weight + n;
|
||||
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
|
||||
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
|
||||
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
|
||||
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
||||
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
|
||||
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
|
||||
}
|
||||
|
||||
|
||||
// QMatrix constructor
|
||||
|
||||
QMatrix::QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_q_weight = _q_weight;
|
||||
cuda_q_perm = _q_perm;
|
||||
cuda_q_invperm = _q_invperm;
|
||||
cuda_q_scale = _q_scale;
|
||||
cuda_q_scale_max = _q_scale_max;
|
||||
cuda_q_groups = _q_groups;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
groupsize = 1;
|
||||
while (groupsize * groups < height) groupsize *= 2;
|
||||
|
||||
// Create group map
|
||||
|
||||
rows_8 = 0;
|
||||
rows_6 = 0;
|
||||
rows_5 = 0;
|
||||
rows_4 = 0;
|
||||
rows_3 = 0;
|
||||
rows_2 = 0;
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
for (int i = 0; i < groups; i++)
|
||||
{
|
||||
int bits = cpu_q_groups[i * 2];
|
||||
if (bits == 8) rows_8 += groupsize;
|
||||
if (bits == 6) rows_6 += groupsize;
|
||||
if (bits == 5) rows_5 += groupsize;
|
||||
if (bits == 4) rows_4 += groupsize;
|
||||
if (bits == 3) rows_3 += groupsize;
|
||||
if (bits == 2) rows_2 += groupsize;
|
||||
}
|
||||
|
||||
free(cpu_q_groups);
|
||||
|
||||
rows_6 += rows_8;
|
||||
rows_5 += rows_6;
|
||||
rows_4 += rows_5;
|
||||
rows_3 += rows_4;
|
||||
rows_2 += rows_3;
|
||||
}
|
||||
else
|
||||
{
|
||||
rows_4 = height;
|
||||
rows_3 = height;
|
||||
rows_2 = height;
|
||||
|
||||
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
|
||||
}
|
||||
|
||||
// Shuffle quantized data
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
|
||||
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n] (GPTQ)
|
||||
|
||||
__global__ void reconstruct_gptq_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_4
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
int t = threadIdx.x;
|
||||
|
||||
if (b_q_perm)
|
||||
{
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
}
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Initial zeros/scale
|
||||
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4][4];
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
b_ptr += size_n;
|
||||
//half* dqh = (half*)dq;
|
||||
if (b_q_perm)
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n]
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
int t = threadIdx.x;
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
dequant_8bit_8(q_0, q_1, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
||||
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_4bit_8(q_0, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_2bit_16(q_0, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
void QMatrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
reconstruct_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_q_scale,
|
||||
cuda_q_scale_max,
|
||||
//cuda_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_8,
|
||||
rows_6,
|
||||
rows_5,
|
||||
rows_4,
|
||||
rows_3,
|
||||
rows_2
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_gptq_qzeros,
|
||||
cuda_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_4
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint16_t* __restrict__ q_perm,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int q_perm_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = q_perm[q_perm_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Reduce to uint16_t
|
||||
|
||||
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
||||
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
||||
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
||||
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
make_sequential_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_new_qweight,
|
||||
cuda_q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
#ifndef _q_matrix_cuh
#define _q_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#define MAX_SUPERGROUPS 16

class QMatrix
{
public:

    int device;
    bool is_gptq;

    int height;
    int width;
    int groups;
    int groupsize;

    int rows_8;
    int rows_6;
    int rows_5;
    int rows_4;
    int rows_3;
    int rows_2;

    uint32_t* cuda_q_weight = NULL;
    uint16_t* cuda_q_perm = NULL;
    uint16_t* cuda_q_invperm = NULL;
    uint32_t* cuda_q_scale = NULL;
    half* cuda_q_scale_max = NULL;
    uint16_t* cuda_q_groups = NULL;
    uint32_t* cuda_gptq_qzeros = NULL;
    half* cuda_gptq_scales = NULL;

    half* temp_dq;

    QMatrix
    (
        const int _device,
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _q_weight,
        uint16_t* _q_perm,
        uint16_t* _q_invperm,
        uint32_t* _q_scale,
        half* _q_scale_max,
        uint16_t* _q_groups,

        uint32_t* _gptq_qzeros,
        half* _gptq_scales,
        uint32_t* _gptq_g_idx,

        half* _temp_dq
    );

    ~QMatrix();

    void reconstruct(half* out);
    void make_sequential(const uint32_t* cpu_g_idx);

private:

};

#endif
|
|
@ -1,103 +0,0 @@
|
|||
#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_2BIT == 1

// Permutation:
//
// ffddbb99 77553311  eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        uint32_t qa0 = qa & 0x03;
        uint32_t qa1 = (qa & 0x0c) >> 2;
        qa >>= 4;
        qb |= (qa1 << (i * 2 + 16));
        qb |= (qa0 << (i * 2));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y4_  = __float2half_rn(1.0f /  4.0f);
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y4  = __halves2half2(y4_,  y4_);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 2.0f);
    const half z4_  = __float2half_rn(-1024.0f /  4.0f - 2.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z4  = __halves2half2(z4_,  z4_);
    const half2 z16 = __halves2half2(z16_, z16_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
    qa >>= 8;
    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9])      + 1024
    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y4,  z4);
    dq[2] = __hfma2(q2.as_half2, y16, z16);
    dq[3] = __hfma2(q3.as_half2, y64, z64);
    dq[4] = __hadd2(q4.as_half2, z1);
    dq[5] = __hfma2(q5.as_half2, y4,  z4);
    dq[6] = __hfma2(q6.as_half2, y16, z16);
    dq[7] = __hfma2(q7.as_half2, y64, z64);
}

#else

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
|
|
@ -1,169 +0,0 @@
|
|||
#ifndef _qdq_3_cuh
#define _qdq_3_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_3BIT == 1

// Permutation:
//
// v9997775 55333111  u8886664 44222000  (u, v lsb)
// vjjjhhhf ffdddbbb  uiiiggge eecccaaa
// vtttrrrp ppnnnlll  usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];

    // qa: aa999888 77766655  54443332 22111000
    // qb: lkkkjjji iihhhggg  fffeeedd dcccbbba
    // qc: vvvuuutt tsssrrrq  qqpppooo nnnmmmll

    uint32_t qd = qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: ..999888 77766655  54443332 22111000
    // qb: ..jjjiii hhhgggff  feeedddc ccbbbaaa
    // qc: ..tttsss rrrqqqpp  pooonnnm mmlllkkk
    // qd:                                vvvuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;

    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }

    // za:  9997775 55333111   8886664 44222000
    // zb:  jjjhhhf ffdddbbb   iiiggge eecccaaa
    // zc:  tttrrrp ppnnnlll   sssqqqo oommmkkk
    // qd:                                vvvuuu

    za |= ((qd & 0x01) >> 0) << 15;
    zb |= ((qd & 0x02) >> 1) << 15;
    zc |= ((qd & 0x04) >> 2) << 15;
    za |= ((qd & 0x08) >> 3) << 31;
    zb |= ((qd & 0x10) >> 4) << 31;
    zc |= ((qd & 0x20) >> 5) << 31;

    // za: v9997775 55333111  u8886664 44222000  (u, v lsb)
    // zb: vjjjhhhf ffdddbbb  uiiiggge eecccaaa
    // zc: vtttrrrp ppnnnlll  usssqqqo oommmkkk

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y8_  = __float2half_rn(1.0f /  8.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y8  = __halves2half2(y8_,  y8_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 4.0f);
    const half z8_  = __float2half_rn(-1024.0f /  8.0f - 4.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z8  = __halves2half2(z8_,  z8_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;

    half2_uint32 q0 ((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
    qa >>= 6;
    half2_uint32 q2 ((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3 ((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
    half2_uint32 q4 ((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
    qa >>= 9;
    qa &= 0x00010001;
    half2_uint32 q5 ((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
    half2_uint32 q6 ((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
    qb >>= 6;
    half2_uint32 q7 ((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
    half2_uint32 q8 ((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
    half2_uint32 q9 ((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
    qb >>= 8;
    qb &= 0x00020002;
    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
    qc >>= 6;
    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
    qc >>= 7;
    qc &= 0x00040004;
    half2_uint32 q15((qa | qb | qc) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
    dq[ 7] = __hadd2( q7.as_half2, z1);
    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
    dq[10] = __hadd2(q10.as_half2, z1);
    dq[11] = __hfma2(q11.as_half2, y8,  z8);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y8,  z8);
    dq[14] = __hfma2(q14.as_half2, y64, z64);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 10; i++) dqh[     i] = dq_ns(exb(q_0, i * 3    , 0x07), 4);
    dqh[10] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb(q_1, i * 3 + 1, 0x07), 4);
    dqh[21] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb(q_2, i * 3 + 2, 0x07), 4);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
|
|
@ -1,227 +0,0 @@
|
|||
#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_4BIT == 1

// Permutation:
//
// 77775555 33331111  66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1z16)[2],
    half2(&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}


__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2, z1z16[0]);             // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);   // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2, z1z16[0]);             // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);   // half2( q[6] - z, q[7] - z )
    }
}

#else

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1)[2],
    half2(&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

#endif

#endif
|
|
@ -1,207 +0,0 @@
|
|||
#ifndef _qdq_5_cuh
|
||||
#define _qdq_5_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_5BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
uint32_t qd = q[3 * stride];
|
||||
uint32_t qe = q[4 * stride];
|
||||
|
||||
// qa: 66555554 44443333 32222211 11100000
|
||||
// qb: ccccbbbb baaaaa99 99988888 77777666
|
||||
// qc: jiiiiihh hhhggggg fffffeee eedddddc
|
||||
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
|
||||
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
|
||||
|
||||
uint32_t qf = qe >> 22;
|
||||
qe <<= 8;
|
||||
qe |= qd >> 24;
|
||||
qd <<= 6;
|
||||
qd |= qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: 555554 44443333 32222211 11100000
|
||||
// qb: bbbbba aaaa9999 98888877 77766666
|
||||
// qc: hhhhhg ggggffff feeeeedd dddccccc
|
||||
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
|
||||
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
uint32_t zd = 0;
|
||||
uint32_t ze = 0;
|
||||
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
|
||||
|
||||
// za: 5555533 33311111 4444422 22200000
|
||||
// zb: bbbbb99 99977777 aaaaa88 88866666
|
||||
// zc: hhhhhff fffddddd gggggee eeeccccc
|
||||
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
|
||||
// ze: tttttrr rrrppppp sssssqq qqqooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
za |= ((qf & 0x001) >> 0) << 15;
|
||||
zb |= ((qf & 0x002) >> 1) << 15;
|
||||
zc |= ((qf & 0x004) >> 2) << 15;
|
||||
zd |= ((qf & 0x008) >> 3) << 15;
|
||||
ze |= ((qf & 0x010) >> 4) << 15;
|
||||
za |= ((qf & 0x020) >> 5) << 31;
|
||||
zb |= ((qf & 0x040) >> 6) << 31;
|
||||
zc |= ((qf & 0x080) >> 7) << 31;
|
||||
zd |= ((qf & 0x100) >> 8) << 31;
|
||||
ze |= ((qf & 0x200) >> 9) << 31;
|
||||
|
||||
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// zb: vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// zc: vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// ze: vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
q[3 * stride] = zd;
|
||||
q[4 * stride] = ze;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y32_ = __float2half_rn(1.0f / 32.0f);
|
||||
const half2 y32 = __halves2half2(y32_, y32_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
|
||||
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z32 = __halves2half2(z32_, z32_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
uint32_t qd = q_3;
|
||||
uint32_t qe = q_4;
|
||||
|
||||
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
|
||||
qa >>= 10;
|
||||
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
qa >>= 5;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
|
||||
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
|
||||
qb >>= 10;
|
||||
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
|
||||
qb >>= 4;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
|
||||
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
|
||||
qc >>= 10;
|
||||
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
|
||||
qc >>= 3;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
|
||||
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
|
||||
qd >>= 10;
|
||||
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
|
||||
qd >>= 2;
|
||||
qd &= 0x00080008;
|
||||
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
|
||||
qe >>= 10;
|
||||
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
|
||||
qe >>= 1;
|
||||
qe &= 0x00100010;
|
||||
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hadd2( q3.as_half2, z1);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hadd2( q6.as_half2, z1);
|
||||
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
|
||||
dq[ 8] = __hadd2( q8.as_half2, z1);
|
||||
dq[ 9] = __hadd2( q9.as_half2, z1);
|
||||
dq[10] = __hfma2(q10.as_half2, y32, z32);
|
||||
dq[11] = __hadd2(q11.as_half2, z1);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y32, z32);
|
||||
dq[14] = __hadd2(q14.as_half2, z1);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
|
||||
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
|
||||
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
|
||||
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
|
||||
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -1,44 +0,0 @@
|
|||
#ifndef _qdq_6_cuh
#define _qdq_6_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_6BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_6bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_6bit_16
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 5; i++) dqh[     i] = dq_ns(exb(q_0, i * 6    , 0x3f), 32);
    dqh[ 5] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
    for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb(q_1, i * 6 + 4, 0x3f), 32);
    dqh[10] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
    for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb(q_2, i * 6 + 2, 0x3f), 32);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
|
||||
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_8BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
|
|
@ -1,51 +0,0 @@
|
|||
#ifndef _qdq_util_cuh
#define _qdq_util_cuh

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

#endif
|
|
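For orientation while reading the dequantization headers above, the following is a minimal host-side Python illustration (not part of the extension) of what `exb()` and `dq()` compute: a bit-field extraction followed by `(q - qzero) * scale`. The packed value, scale and zero-point below are made-up example numbers.

```python
# Host-side sketch mirroring the device helpers exb() and dq() defined above.
def exb(q: int, shift: int, mask: int) -> int:
    return (q >> shift) & mask           # extract a bit-field from the packed word

def dq(q: int, qzero: int, scale: float) -> float:
    return (q - qzero) * scale           # dequantize against zero-point and scale

packed = 0x76543210                      # eight 4-bit values packed LSB-first (example)
vals = [dq(exb(packed, i * 4, 0x0F), qzero=8, scale=0.1) for i in range(8)]
print(vals)                              # approximately [-0.8, -0.7, ..., -0.1]
```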
@ -1,32 +0,0 @@

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))

#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))

__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
    half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
    qs_h = __hmul(qs_h, qs_h);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ float clamp(float x, float a, float b)
{
    return fmaxf(a, fminf(b, x));
}
@ -1,134 +0,0 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#include "config.h"

#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"

#include "cpp/util.h"

// Some decluttering macros

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")


// Quant matrix

uintptr_t make_q_matrix
(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    torch::Tensor q_invperm,
    torch::Tensor q_scale,
    torch::Tensor q_scale_max,
    torch::Tensor q_groups,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
    torch::Tensor temp_dq
)
{
    TORCH_CHECK_DTYPE(q_weight, kInt);
    TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
    TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
    TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);

    TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);

    int device = q_weight.device().index();
    int width = q_weight.size(1);
    int groups;
    int height;

    if (!q_scale.device().is_meta())
    {
        TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
        TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
        groups = q_scale.size(0);
        height = q_invperm.size(0);
    }
    else
    {
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
        groups = gptq_qzeros.size(0);
        height = q_weight.size(0) * 8;
    }

    TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")

    QMatrix* m = new QMatrix
    (
        device,
        height,
        width,
        groups,
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
        q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
        q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
        q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
        q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
        gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
        (half*) temp_dq.data_ptr()
    );

    return reinterpret_cast<uintptr_t> (m);
}

void gemm_half_q_half
(
    torch::Tensor a,
    uintptr_t b,
    torch::Tensor c,
    bool force_cuda
)
{
    QMatrix* qm = reinterpret_cast<QMatrix*> (b);

    TORCH_CHECK_DTYPE(a, kHalf);
    TORCH_CHECK_DTYPE(c, kHalf);
    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
    TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
    TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")

    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));

    gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        qm,
        (half*) c.data_ptr(),
        c.size(0), // m
        c.size(1), // n
        a.size(1), // k
        true,
        NULL,
        force_cuda
    );
}

// Bindings

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
    m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
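To show how these two bindings fit together, here is a hedged Python sketch of driving them once the extension is compiled. The module name `q_ext`, the tensor shapes, and the placement of meta tensors as "not provided" sentinels are assumptions for illustration; only the function names `make_q_matrix` and `gemm_half_q_half` and their argument order come from the bindings above.

```python
import torch
import q_ext  # hypothetical name of the compiled TORCH_EXTENSION_NAME module

# A tensor on the "meta" device stands for "argument not provided",
# matching the .device().is_meta() checks in the C++ above.
def absent(dtype):
    return torch.empty(0, dtype=dtype, device="meta")

# GPTQ-style inputs; all shapes below are placeholders (4-bit, 8 values per int32)
q_weight    = torch.zeros((512, 4096), dtype=torch.int32, device="cuda")
gptq_qzeros = torch.zeros((32, 512), dtype=torch.int32, device="cuda")
gptq_scales = torch.zeros((32, 4096), dtype=torch.float16, device="cuda")
gptq_g_idx  = torch.zeros((4096,), dtype=torch.int32, device="cuda")
temp_dq     = torch.zeros((4096 * 4096,), dtype=torch.float16, device="cuda")

handle = q_ext.make_q_matrix(
    q_weight, absent(torch.int16), absent(torch.int16), absent(torch.int32),
    absent(torch.float16), absent(torch.int16),
    gptq_qzeros, gptq_scales, gptq_g_idx, temp_dq,
)

a = torch.randn((1, 4096), dtype=torch.float16, device="cuda")
c = torch.empty((1, 4096), dtype=torch.float16, device="cuda")
q_ext.gemm_half_q_half(a, handle, c, False)  # c = a @ dequant(b)
```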
File diff suppressed because it is too large
@ -1,480 +0,0 @@
|
|||
#include<omp.h>
|
||||
#include<immintrin.h>
|
||||
#include<fstream>
|
||||
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))
|
||||
inline
|
||||
void q2gemm_gs(const float* __restrict__ input,
|
||||
const int* __restrict__ W,
|
||||
const float* __restrict__ scales,
|
||||
const float* __restrict__ zeros,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ sums,
|
||||
float* __restrict__ output,
|
||||
const int n,
|
||||
const int m,
|
||||
const int t,
|
||||
const int nb,
|
||||
const int mb,
|
||||
const int tb,
|
||||
int ogtt,
|
||||
const int gs,
|
||||
const int cutoff){
|
||||
#pragma omp parallel num_threads(8)
|
||||
{
|
||||
int tid;
|
||||
const int mu = 16;
|
||||
const int nu = 1;
|
||||
const int tu = 32;
|
||||
const int on = n / nb;
|
||||
const int om = m / mb;
|
||||
const __m256i mask = _mm256_set1_epi32(3);
|
||||
tid = omp_get_thread_num();
|
||||
int tt = ogtt;
|
||||
if(tid >= cutoff){
|
||||
tt -= tb;
|
||||
}
|
||||
const int base_output = tid >= cutoff ?
|
||||
(tid-cutoff)*tt + (tt+tb)*cutoff:
|
||||
tid*tt;
|
||||
const int base_W = tid >= cutoff ?
|
||||
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16:
|
||||
tid*tt*m/16;
|
||||
for(int j = 0; j < tt; j+=tb){
|
||||
for(int i = 0; i < on; i++) {
|
||||
for(int k = 0; k < om; k++) {
|
||||
for(int i1 = 0; i1 < nb; i1+=nu) {
|
||||
int j1 = 0;
|
||||
for(; j1 < tb-tu+1; j1+=tu) {
|
||||
for(int k1 = 0; k1 < mb; k1+=gs) {
|
||||
__m256 acc0_0 = _mm256_setzero_ps();
|
||||
__m256 acc0_8 = _mm256_setzero_ps();
|
||||
__m256 acc0_16 = _mm256_setzero_ps();
|
||||
__m256 acc0_24 = _mm256_setzero_ps();
|
||||
for(int k2 = k1; k2 < k1+gs; k2+=16)
|
||||
{
|
||||
__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]);
|
||||
__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]);
|
||||
__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]);
|
||||
__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]);
|
||||
__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]);
|
||||
__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]);
|
||||
__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]);
|
||||
__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]);
|
||||
__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]);
|
||||
__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]);
|
||||
__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]);
|
||||
__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]);
|
||||
__m256i ws0_8 = _mm256_srli_epi32(w0, 16);
|
||||
__m256i ws8_8 = _mm256_srli_epi32(w8, 16);
|
||||
__m256i ws16_8 = _mm256_srli_epi32(w16, 16);
|
||||
__m256i ws24_8 = _mm256_srli_epi32(w24, 16);
|
||||
__m256i wsa0_8= _mm256_and_si256(ws0_8, mask);
|
||||
__m256i wsa8_8= _mm256_and_si256(ws8_8, mask);
|
||||
__m256i wsa16_8= _mm256_and_si256(ws16_8, mask);
|
||||
__m256i wsa24_8= _mm256_and_si256(ws24_8, mask);
|
||||
__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8);
|
||||
__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8);
|
||||
__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8);
|
||||
__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24);
|
||||
__m256i ws0_9 = _mm256_srli_epi32(w0, 18);
|
||||
__m256i ws8_9 = _mm256_srli_epi32(w8, 18);
|
||||
__m256i ws16_9 = _mm256_srli_epi32(w16, 18);
|
||||
__m256i ws24_9 = _mm256_srli_epi32(w24, 18);
|
||||
__m256i wsa0_9= _mm256_and_si256(ws0_9, mask);
|
||||
__m256i wsa8_9= _mm256_and_si256(ws8_9, mask);
|
||||
__m256i wsa16_9= _mm256_and_si256(ws16_9, mask);
|
||||
__m256i wsa24_9= _mm256_and_si256(ws24_9, mask);
|
||||
__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9);
|
||||
__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9);
|
||||
__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9);
|
||||
__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24);
|
||||
__m256i ws0_10 = _mm256_srli_epi32(w0, 20);
|
||||
__m256i ws8_10 = _mm256_srli_epi32(w8, 20);
|
||||
__m256i ws16_10 = _mm256_srli_epi32(w16, 20);
|
||||
__m256i ws24_10 = _mm256_srli_epi32(w24, 20);
|
||||
__m256i wsa0_10= _mm256_and_si256(ws0_10, mask);
|
||||
__m256i wsa8_10= _mm256_and_si256(ws8_10, mask);
|
||||
__m256i wsa16_10= _mm256_and_si256(ws16_10, mask);
|
||||
__m256i wsa24_10= _mm256_and_si256(ws24_10, mask);
|
||||
__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10);
|
||||
__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10);
|
||||
__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10);
|
||||
__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24);
|
||||
__m256i ws0_11 = _mm256_srli_epi32(w0, 22);
|
||||
__m256i ws8_11 = _mm256_srli_epi32(w8, 22);
|
||||
__m256i ws16_11 = _mm256_srli_epi32(w16, 22);
|
||||
__m256i ws24_11 = _mm256_srli_epi32(w24, 22);
|
||||
__m256i wsa0_11= _mm256_and_si256(ws0_11, mask);
|
||||
__m256i wsa8_11= _mm256_and_si256(ws8_11, mask);
|
||||
__m256i wsa16_11= _mm256_and_si256(ws16_11, mask);
|
||||
__m256i wsa24_11= _mm256_and_si256(ws24_11, mask);
|
||||
__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11);
|
||||
__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11);
|
||||
__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11);
|
||||
__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24);
|
||||
__m256i ws0_12 = _mm256_srli_epi32(w0, 24);
|
||||
__m256i ws8_12 = _mm256_srli_epi32(w8, 24);
|
||||
__m256i ws16_12 = _mm256_srli_epi32(w16, 24);
|
||||
__m256i ws24_12 = _mm256_srli_epi32(w24, 24);
|
||||
__m256i wsa0_12= _mm256_and_si256(ws0_12, mask);
|
||||
__m256i wsa8_12= _mm256_and_si256(ws8_12, mask);
|
||||
__m256i wsa16_12= _mm256_and_si256(ws16_12, mask);
|
||||
__m256i wsa24_12= _mm256_and_si256(ws24_12, mask);
|
||||
__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12);
|
||||
__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12);
|
||||
__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12);
|
||||
__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24);
|
||||
__m256i ws0_13 = _mm256_srli_epi32(w0, 26);
|
||||
__m256i ws8_13 = _mm256_srli_epi32(w8, 26);
|
||||
__m256i ws16_13 = _mm256_srli_epi32(w16, 26);
|
||||
__m256i ws24_13 = _mm256_srli_epi32(w24, 26);
|
||||
__m256i wsa0_13= _mm256_and_si256(ws0_13, mask);
|
||||
__m256i wsa8_13= _mm256_and_si256(ws8_13, mask);
|
||||
__m256i wsa16_13= _mm256_and_si256(ws16_13, mask);
|
||||
__m256i wsa24_13= _mm256_and_si256(ws24_13, mask);
|
||||
__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13);
|
||||
__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13);
|
||||
__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13);
|
||||
__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24);
|
||||
__m256i ws0_14 = _mm256_srli_epi32(w0, 28);
|
||||
__m256i ws8_14 = _mm256_srli_epi32(w8, 28);
|
||||
__m256i ws16_14 = _mm256_srli_epi32(w16, 28);
|
||||
__m256i ws24_14 = _mm256_srli_epi32(w24, 28);
|
||||
__m256i wsa0_14= _mm256_and_si256(ws0_14, mask);
|
||||
__m256i wsa8_14= _mm256_and_si256(ws8_14, mask);
|
||||
__m256i wsa16_14= _mm256_and_si256(ws16_14, mask);
|
||||
__m256i wsa24_14= _mm256_and_si256(ws24_14, mask);
|
||||
__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14);
|
||||
__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14);
|
||||
__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14);
|
||||
__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24);
|
||||
__m256i ws0_15 = _mm256_srli_epi32(w0, 30);
|
||||
__m256i ws8_15 = _mm256_srli_epi32(w8, 30);
|
||||
__m256i ws16_15 = _mm256_srli_epi32(w16, 30);
|
||||
__m256i ws24_15 = _mm256_srli_epi32(w24, 30);
|
||||
__m256i wsa0_15= _mm256_and_si256(ws0_15, mask);
|
||||
__m256i wsa8_15= _mm256_and_si256(ws8_15, mask);
|
||||
__m256i wsa16_15= _mm256_and_si256(ws16_15, mask);
|
||||
__m256i wsa24_15= _mm256_and_si256(ws24_15, mask);
|
||||
__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15);
|
||||
__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15);
|
||||
__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15);
|
||||
__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24);
|
||||
__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]);
|
||||
__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]);
|
||||
__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]);
|
||||
__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]);
|
||||
__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]);
|
||||
__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]);
|
||||
__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]);
|
||||
__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]);
|
||||
__m256i ws0_0 = _mm256_srli_epi32(w0, 0);
|
||||
__m256i ws8_0 = _mm256_srli_epi32(w8, 0);
|
||||
__m256i ws16_0 = _mm256_srli_epi32(w16, 0);
|
||||
__m256i ws24_0 = _mm256_srli_epi32(w24, 0);
|
||||
__m256i wsa0_0= _mm256_and_si256(ws0_0, mask);
|
||||
__m256i wsa8_0= _mm256_and_si256(ws8_0, mask);
|
||||
__m256i wsa16_0= _mm256_and_si256(ws16_0, mask);
|
||||
__m256i wsa24_0= _mm256_and_si256(ws24_0, mask);
|
||||
__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0);
|
||||
__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0);
|
||||
__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0);
|
||||
__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24);
|
||||
__m256i ws0_1 = _mm256_srli_epi32(w0, 2);
|
||||
__m256i ws8_1 = _mm256_srli_epi32(w8, 2);
|
||||
__m256i ws16_1 = _mm256_srli_epi32(w16, 2);
|
||||
__m256i ws24_1 = _mm256_srli_epi32(w24, 2);
|
||||
__m256i wsa0_1= _mm256_and_si256(ws0_1, mask);
|
||||
__m256i wsa8_1= _mm256_and_si256(ws8_1, mask);
|
||||
__m256i wsa16_1= _mm256_and_si256(ws16_1, mask);
|
||||
__m256i wsa24_1= _mm256_and_si256(ws24_1, mask);
|
||||
__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1);
|
||||
__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1);
|
||||
__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1);
|
||||
__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24);
|
||||
__m256i ws0_2 = _mm256_srli_epi32(w0, 4);
|
||||
__m256i ws8_2 = _mm256_srli_epi32(w8, 4);
|
||||
__m256i ws16_2 = _mm256_srli_epi32(w16, 4);
|
||||
__m256i ws24_2 = _mm256_srli_epi32(w24, 4);
|
||||
__m256i wsa0_2= _mm256_and_si256(ws0_2, mask);
|
||||
__m256i wsa8_2= _mm256_and_si256(ws8_2, mask);
|
||||
__m256i wsa16_2= _mm256_and_si256(ws16_2, mask);
|
||||
__m256i wsa24_2= _mm256_and_si256(ws24_2, mask);
|
||||
__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2);
|
||||
__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2);
|
||||
__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2);
|
||||
__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24);
|
||||
__m256i ws0_3 = _mm256_srli_epi32(w0, 6);
|
||||
__m256i ws8_3 = _mm256_srli_epi32(w8, 6);
|
||||
__m256i ws16_3 = _mm256_srli_epi32(w16, 6);
|
||||
__m256i ws24_3 = _mm256_srli_epi32(w24, 6);
|
||||
__m256i wsa0_3= _mm256_and_si256(ws0_3, mask);
|
||||
__m256i wsa8_3= _mm256_and_si256(ws8_3, mask);
|
||||
__m256i wsa16_3= _mm256_and_si256(ws16_3, mask);
|
||||
__m256i wsa24_3= _mm256_and_si256(ws24_3, mask);
|
||||
__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3);
|
||||
__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3);
|
||||
__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3);
|
||||
__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24);
|
||||
__m256i ws0_4 = _mm256_srli_epi32(w0, 8);
|
||||
__m256i ws8_4 = _mm256_srli_epi32(w8, 8);
|
||||
__m256i ws16_4 = _mm256_srli_epi32(w16, 8);
|
||||
__m256i ws24_4 = _mm256_srli_epi32(w24, 8);
|
||||
__m256i wsa0_4= _mm256_and_si256(ws0_4, mask);
|
||||
__m256i wsa8_4= _mm256_and_si256(ws8_4, mask);
|
||||
__m256i wsa16_4= _mm256_and_si256(ws16_4, mask);
|
||||
__m256i wsa24_4= _mm256_and_si256(ws24_4, mask);
|
||||
__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4);
|
||||
__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4);
|
||||
__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4);
|
||||
__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24);
|
||||
__m256i ws0_5 = _mm256_srli_epi32(w0, 10);
|
||||
__m256i ws8_5 = _mm256_srli_epi32(w8, 10);
|
||||
__m256i ws16_5 = _mm256_srli_epi32(w16, 10);
|
||||
__m256i ws24_5 = _mm256_srli_epi32(w24, 10);
|
||||
__m256i wsa0_5= _mm256_and_si256(ws0_5, mask);
|
||||
__m256i wsa8_5= _mm256_and_si256(ws8_5, mask);
|
||||
__m256i wsa16_5= _mm256_and_si256(ws16_5, mask);
|
||||
__m256i wsa24_5= _mm256_and_si256(ws24_5, mask);
|
||||
__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5);
|
||||
__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5);
|
||||
__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5);
|
||||
__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24);
|
||||
__m256i ws0_6 = _mm256_srli_epi32(w0, 12);
|
||||
__m256i ws8_6 = _mm256_srli_epi32(w8, 12);
|
||||
__m256i ws16_6 = _mm256_srli_epi32(w16, 12);
|
||||
__m256i ws24_6 = _mm256_srli_epi32(w24, 12);
|
||||
__m256i wsa0_6= _mm256_and_si256(ws0_6, mask);
|
||||
__m256i wsa8_6= _mm256_and_si256(ws8_6, mask);
|
||||
__m256i wsa16_6= _mm256_and_si256(ws16_6, mask);
|
||||
__m256i wsa24_6= _mm256_and_si256(ws24_6, mask);
|
||||
__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6);
|
||||
__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6);
|
||||
__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6);
|
||||
__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24);
|
||||
__m256i ws0_7 = _mm256_srli_epi32(w0, 14);
|
||||
__m256i ws8_7 = _mm256_srli_epi32(w8, 14);
|
||||
__m256i ws16_7 = _mm256_srli_epi32(w16, 14);
|
||||
__m256i ws24_7 = _mm256_srli_epi32(w24, 14);
|
||||
__m256i wsa0_7= _mm256_and_si256(ws0_7, mask);
|
||||
__m256i wsa8_7= _mm256_and_si256(ws8_7, mask);
|
||||
__m256i wsa16_7= _mm256_and_si256(ws16_7, mask);
|
||||
__m256i wsa24_7= _mm256_and_si256(ws24_7, mask);
|
||||
__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7);
|
||||
__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7);
|
||||
__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7);
|
||||
__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24);
|
||||
}
|
||||
__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]);
|
||||
__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]);
|
||||
__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]);
|
||||
__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]);
|
||||
__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]);
|
||||
__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]);
|
||||
__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]);
|
||||
__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]);
|
||||
__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0);
|
||||
__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8);
|
||||
__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16);
|
||||
__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma omp barrier
|
||||
const int ngs = m/gs;
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < tt; j+=32){
|
||||
__m256 acc0 = _mm256_setzero_ps();
|
||||
__m256 acc8 = _mm256_setzero_ps();
|
||||
__m256 acc16 = _mm256_setzero_ps();
|
||||
__m256 acc24 = _mm256_setzero_ps();
|
||||
for (int i1 = 0; i1 < ngs; i1++){
|
||||
__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
|
||||
__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]);
|
||||
__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]);
|
||||
__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]);
|
||||
__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]);
|
||||
__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]);
|
||||
__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]);
|
||||
__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]);
|
||||
__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]);
|
||||
__m256 zs0 = _mm256_mul_ps(z0, s0);
|
||||
__m256 zs8 = _mm256_mul_ps(z8, s8);
|
||||
__m256 zs16 = _mm256_mul_ps(z16, s16);
|
||||
__m256 zs24 = _mm256_mul_ps(z24, s24);
|
||||
acc0 = _mm256_fmadd_ps(zs0, r, acc0);
|
||||
acc8 = _mm256_fmadd_ps(zs8, r, acc8);
|
||||
acc16 = _mm256_fmadd_ps(zs16, r, acc16);
|
||||
acc24 = _mm256_fmadd_ps(zs24, r, acc24);
|
||||
}
|
||||
__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]);
|
||||
__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]);
|
||||
__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]);
|
||||
__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]);
|
||||
__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]);
|
||||
__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]);
|
||||
__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]);
|
||||
__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]);
|
||||
__m256 o10 = _mm256_add_ps(o0, acc0);
|
||||
__m256 o18 = _mm256_add_ps(o8, acc8);
|
||||
__m256 o116 = _mm256_add_ps(o16, acc16);
|
||||
__m256 o124 = _mm256_add_ps(o24, acc24);
|
||||
__m256 o20 = _mm256_add_ps(o10, b0);
|
||||
__m256 o28 = _mm256_add_ps(o18, b8);
|
||||
__m256 o216 = _mm256_add_ps(o116, b16);
|
||||
__m256 o224 = _mm256_add_ps(o124, b24);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void qforward(const float* __restrict__ input,
|
||||
const int* __restrict__ W,
|
||||
const float* __restrict__ scales,
|
||||
const float* __restrict__ zeros,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ sums,
|
||||
float* __restrict__ output,
|
||||
int n,
|
||||
int m,
|
||||
int t) {
|
||||
q2gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, 1, 1024, 32, 512, 64, 9);
|
||||
}
|
||||
inline void pack_input(float* A, float* B){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 1;
|
||||
const int M = 4096;
|
||||
const int nb = 1;
|
||||
const int mb = 1024;
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int j = 0; j < M; j+=mb){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void pack_qw_inner(int* A, int* B, int cutoff){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 256;
|
||||
const int M = 4096;
|
||||
const int nb = 64;
|
||||
int mb = 32;
|
||||
for(int j = 0, tid = 0; j < M; j+=mb, tid++){
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void pack_qw(int* A, int* B){
|
||||
pack_qw_inner(A, B, 65);
|
||||
}
|
||||
inline void pack_output(float* A, float* B){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 1;
|
||||
const int M = 4096;
|
||||
const int nb = 1;
|
||||
const int mb = 32;
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int j = 0; j < M; j+=mb){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void print_parameters(){
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
outfile << 2 << "," << 1 << "," << 16 << "," << 32 << "," << 8 << "," << 8 << "," << 64 << ",";
|
||||
}
|
File diff suppressed because it is too large
@ -1,149 +0,0 @@
|
|||
|
||||
def load_int(to, address, const=True):
|
||||
if const:
|
||||
return f"const __m256i {to} = _mm256_loadu_si256({address});"
|
||||
else:
|
||||
return f"__m256i {to} = _mm256_loadu_si256({address});"
|
||||
|
||||
def load_fp(to, address, const=True):
|
||||
if const:
|
||||
return f"const __m256 {to} = _mm256_loadu_ps({address});"
|
||||
else:
|
||||
return f"__m256 {to} = _mm256_loadu_ps({address});"
|
||||
|
||||
# to = a * b + c
|
||||
def vfma(to, a, b, c):
|
||||
return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"
|
||||
|
||||
def vsrli(to, a, b):
|
||||
return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"
|
||||
|
||||
def vand(to, a, b):
|
||||
return f"const __m256i {to} = _mm256_and_si256({a}, {b});"
|
||||
|
||||
def vbroadcast_fp(to, a):
|
||||
return f"const __m256 {to} = _mm256_set1_ps({a});"
|
||||
|
||||
def vbroadcast_int32(to, a):
|
||||
return f"__m256i {to} = _mm256_set1_epi32({a});"
|
||||
|
||||
def vsetzero(to):
|
||||
return f"__m256 {to} = _mm256_setzero_ps();"
|
||||
|
||||
def vcvtepi32_ps(to, a):
|
||||
return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"
|
||||
|
||||
def _256extractf128_ps(to, a, imm):
|
||||
return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"
|
||||
|
||||
def _256castps256_ps128(to, a):
|
||||
return f"const __m128 {to} = _mm256_castps256_ps128({a});"
|
||||
|
||||
def _add_ps(to, a, b):
|
||||
return f"const __m128 {to} = _mm_add_ps({a}, {b});"
|
||||
|
||||
def _movehl_ps(to, a, b):
|
||||
return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"
|
||||
|
||||
def _shuffle_ps(to, a, b, imm):
|
||||
return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"
|
||||
|
||||
def _cvtss_f32(to, a):
|
||||
return f"const float {to} = _mm_cvtss_f32({a});"
|
||||
|
||||
def _reduce8_acc(a, b, c, d, e, f, g, h):
|
||||
res = ""
|
||||
res += _256extractf128_ps("hi_quad0", a, 1)
|
||||
res += _256extractf128_ps("hi_quad1", b, 1)
|
||||
res += _256extractf128_ps("hi_quad2", c, 1)
|
||||
res += _256extractf128_ps("hi_quad3", d, 1)
|
||||
res += _256extractf128_ps("hi_quad4", e, 1)
|
||||
res += _256extractf128_ps("hi_quad5", f, 1)
|
||||
res += _256extractf128_ps("hi_quad6", g, 1)
|
||||
res += _256extractf128_ps("hi_quad7", h, 1)
|
||||
|
||||
res += _256castps256_ps128("lo_quad0", a)
|
||||
res += _256castps256_ps128("lo_quad1", b)
|
||||
res += _256castps256_ps128("lo_quad2", c)
|
||||
res += _256castps256_ps128("lo_quad3", d)
|
||||
res += _256castps256_ps128("lo_quad4", e)
|
||||
res += _256castps256_ps128("lo_quad5", f)
|
||||
res += _256castps256_ps128("lo_quad6", g)
|
||||
res += _256castps256_ps128("lo_quad7", h)
|
||||
|
||||
res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
|
||||
res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
|
||||
res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
|
||||
res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
|
||||
res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
|
||||
res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
|
||||
res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
|
||||
res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")
|
||||
|
||||
res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
|
||||
res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
|
||||
res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
|
||||
res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
|
||||
res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
|
||||
res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
|
||||
res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
|
||||
res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")
|
||||
|
||||
res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
|
||||
res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
|
||||
res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
|
||||
res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
|
||||
res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
|
||||
res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
|
||||
res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
|
||||
res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")
|
||||
|
||||
res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
|
||||
res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
|
||||
res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
|
||||
res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
|
||||
res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
|
||||
res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
|
||||
res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
|
||||
res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)
|
||||
|
||||
res += _add_ps("sum0", "sum_dual0", "hi0")
|
||||
res += _add_ps("sum1", "sum_dual1", "hi1")
|
||||
res += _add_ps("sum2", "sum_dual2", "hi2")
|
||||
res += _add_ps("sum3", "sum_dual3", "hi3")
|
||||
res += _add_ps("sum4", "sum_dual4", "hi4")
|
||||
res += _add_ps("sum5", "sum_dual5", "hi5")
|
||||
res += _add_ps("sum6", "sum_dual6", "hi6")
|
||||
res += _add_ps("sum7", "sum_dual7", "hi7")
|
||||
|
||||
res += _cvtss_f32(f"f{a}", "sum0")
|
||||
res += _cvtss_f32(f"f{b}", "sum1")
|
||||
res += _cvtss_f32(f"f{c}", "sum2")
|
||||
res += _cvtss_f32(f"f{d}", "sum3")
|
||||
res += _cvtss_f32(f"f{e}", "sum4")
|
||||
res += _cvtss_f32(f"f{f}", "sum5")
|
||||
res += _cvtss_f32(f"f{g}", "sum6")
|
||||
res += _cvtss_f32(f"f{h}", "sum7")
|
||||
|
||||
return res
|
||||
|
||||
acc_idx = 0
|
||||
def _reduce_add(a):
|
||||
global acc_idx
|
||||
res = ""
|
||||
res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
|
||||
res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
|
||||
res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
|
||||
res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
|
||||
res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
|
||||
res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
|
||||
res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
|
||||
res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
|
||||
acc_idx += 1
|
||||
return res
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Binary file not shown.
|
@ -1,302 +0,0 @@
|
|||
#include <iostream>
|
||||
#include "forward.h"
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))
|
||||
|
||||
void print_matrix(std::string name, float* A, int N, int M){
|
||||
std::cout<<name<<std::endl;
|
||||
for(int i = 0; i < N; i++){
|
||||
for(int j = 0; j < M; j++){
|
||||
std::cout << A[i*M+j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout<<std::endl;
|
||||
}
|
||||
|
||||
void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
|
||||
// triple loop matmul and add bias
|
||||
for (int i = 0; i < n; i++){
|
||||
for (int j = 0; j < t; j++){
|
||||
float sum = 0;
|
||||
for (int k = 0; k < m; k++){
|
||||
sum += A[i*m+k] * B[k*t+j];
|
||||
}
|
||||
C[i*t+j] += sum + bias[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute_reduction(float *in, float *out, int n, int m, int gs){
|
||||
int ng;
|
||||
if(gs == -1){
|
||||
ng = 1;
|
||||
gs = m;
|
||||
}else{
|
||||
ng = m/gs;
|
||||
}
|
||||
for(int i = 0; i < n; i++){
|
||||
for(int j0 = 0; j0 < m; j0+=gs){
|
||||
int j = j0/gs;
|
||||
out[i*ng+j] = 0;
|
||||
for(int j1 = j0; j1 < j0+gs; j1++){
|
||||
out[i*ng+j] += in[i*m+j1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
|
||||
//find scales and zeros arrays
|
||||
if(gs == -1){
|
||||
gs = n;
|
||||
}
|
||||
float range = (1<<bits) - 1;
|
||||
int packed = 32 / bits;
|
||||
|
||||
for(int i0 = 0; i0 < n; i0+=gs){
|
||||
int row = i0/gs;
|
||||
for(int j = 0; j < m; j++){
|
||||
float min = A[i0*m + j];
|
||||
float max = A[i0*m + j];
|
||||
for(int i1 = i0; i1 < i0+gs; i1++){
|
||||
min = mymin(min, A[i1*m+j]);
|
||||
max = mymax(max, A[i1*m+j]);
|
||||
}
|
||||
scales[row*m + j] = (max-min)/range;
|
||||
zeros[row*m + j ] = min;
|
||||
}
|
||||
for(int j = 0; j < m; j++){
|
||||
for (int i1 = i0; i1 < i0+gs; i1++){
|
||||
uint32_t acc = 0;
|
||||
int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
|
||||
float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
|
||||
BQ[i1*m+j] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
|
||||
//find scales and zeros arrays
|
||||
if(gs == -1){
|
||||
gs = n;
|
||||
}
|
||||
float range = (1<<bits) - 1;
|
||||
int packed = 32 / bits;
|
||||
|
||||
for(int i0 = 0; i0 < n; i0+=gs){
|
||||
int row = i0/gs;
|
||||
for(int j = 0; j < m; j++){
|
||||
float min = A[i0*m + j];
|
||||
float max = A[i0*m + j];
|
||||
for(int i1 = i0; i1 < i0+gs; i1++){
|
||||
min = mymin(min, A[i1*m+j]);
|
||||
max = mymax(max, A[i1*m+j]);
|
||||
}
|
||||
scales[row*m + j] = (max-min)/range;
|
||||
zeros[row*m + j ] = min;
|
||||
}
|
||||
for(int j = 0; j < m; j++){
|
||||
if(bits == 3){
|
||||
for (int i1 = i0; i1 < i0+gs; i1+=32){
|
||||
uint32_t acc = 0;
|
||||
int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
|
||||
int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
|
||||
int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
|
||||
int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
|
||||
int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
|
||||
int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
|
||||
int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
|
||||
int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
|
||||
int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
|
||||
int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
|
||||
int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
|
||||
int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
|
||||
int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
|
||||
int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
|
||||
int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
|
||||
int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
|
||||
int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
|
||||
int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
|
||||
int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
|
||||
int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
|
||||
int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
|
||||
int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
|
||||
int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
|
||||
int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
|
||||
int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
|
||||
int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
|
||||
int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
|
||||
int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
|
||||
int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
|
||||
int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
|
||||
int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
|
||||
int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
|
||||
int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
|
||||
int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
|
||||
|
||||
int acc0 = 0, acc1 = 0, acc2 = 0;
|
||||
|
||||
acc0 |= temp0;
|
||||
acc0 |= temp1;
|
||||
acc0 |= temp2;
|
||||
acc0 |= temp3;
|
||||
acc0 |= temp4;
|
||||
acc0 |= temp5;
|
||||
acc0 |= temp6;
|
||||
acc0 |= temp7;
|
||||
acc0 |= temp8;
|
||||
acc0 |= temp9;
|
||||
acc0 |= temp10_0;
|
||||
|
||||
acc1 |= temp10_1;
|
||||
acc1 |= temp11;
|
||||
acc1 |= temp12;
|
||||
acc1 |= temp13;
|
||||
acc1 |= temp14;
|
||||
acc1 |= temp15;
|
||||
acc1 |= temp16;
|
||||
acc1 |= temp17;
|
||||
acc1 |= temp18;
|
||||
acc1 |= temp19;
|
||||
acc1 |= temp20;
|
||||
acc1 |= temp21_0;
|
||||
|
||||
acc2 |= temp21_1;
|
||||
acc2 |= temp22;
|
||||
acc2 |= temp23;
|
||||
acc2 |= temp24;
|
||||
acc2 |= temp25;
|
||||
acc2 |= temp26;
|
||||
acc2 |= temp27;
|
||||
acc2 |= temp28;
|
||||
acc2 |= temp29;
|
||||
acc2 |= temp30;
|
||||
acc2 |= temp31;
|
||||
|
||||
BQ[(3*i1/32)*m+j] = acc0;
|
||||
BQ[(3*i1/32+1)*m+j] = acc1;
|
||||
BQ[(3*i1/32+2)*m+j] = acc2;
|
||||
}
|
||||
|
||||
}else{
|
||||
for (int i1 = i0; i1 < i0+gs; i1+=packed){
|
||||
uint32_t acc = 0;
|
||||
for (int i2 = i1; i2 < i1+packed; i2++){
|
||||
int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
|
||||
acc = acc | (temp << (bits*(i2-i1)));
|
||||
}
|
||||
BQ[(i1/packed)*m+j] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]){
|
||||
// read n m t from args
|
||||
if(argc == 0){std::cout << "Parameters not given\n"; return 0;}
|
||||
int n = atoi(argv[1]);
|
||||
int m = atoi(argv[2]);
|
||||
int t = atoi(argv[3]);
|
||||
int bits = atoi(argv[4]);
|
||||
int gs = atoi(argv[5]);
|
||||
int ng;
|
||||
if(gs == -1){
|
||||
ng = 1;
|
||||
}else{
|
||||
ng = m/gs;
|
||||
}
|
||||
float* A = new float[n*m];
|
||||
float* AB = new float[n*m];
|
||||
float* B = new float[m*t];
|
||||
float* BQS = new float[m*t];
|
||||
float* scales = new float[t*ng];
|
||||
float* zeros = new float[t*ng];
|
||||
int* BQ = new int[m*t/8];
|
||||
int* BQB = new int[m*t/8];
|
||||
float* sums = new float[n*ng];
|
||||
float* bias = new float[t];
|
||||
float* C = new float[n*t];
|
||||
float* CB = new float[n*t];
|
||||
float* C2 = new float[n*t];
|
||||
srand(1);
|
||||
for (int i = 0; i < n*m; i++){
|
||||
A[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < t*m; i++){
|
||||
B[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < t; i++){
|
||||
bias[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < n*t; i++){
|
||||
C[i] = 0.0;
|
||||
C2[i] = 0.0;
|
||||
}
|
||||
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
|
||||
quantize(B,BQ,scales,zeros,m,t,bits,gs);
|
||||
|
||||
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
|
||||
quantize(B,BQ,scales,zeros,m,t,bits,gs);
|
||||
oracle_mmadd(A, BQS, bias, C, n, m, t);
|
||||
pack_input(A,AB);
|
||||
pack_qw(BQ,BQB);
|
||||
pack_output(C,CB);
|
||||
|
||||
compute_reduction(A,sums,n,m,gs);
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
|
||||
float norm = 0.0;
|
||||
for (int i = 0; i < n*t; i++){
|
||||
norm += (C[i] - C2[i]) * (C[i] - C2[i]);
|
||||
}
|
||||
if(norm / (n*t) < 0.0001){
|
||||
int iter = 30;
|
||||
for(int _ = 0; _ < iter; _++){
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
}
|
||||
|
||||
int num_runs = 15;
|
||||
std::vector<long int> runs(num_runs);
|
||||
for(int r = 0; r < num_runs; r++){
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
for(int _ = 0; _ < iter; _++){
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
}
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
|
||||
|
||||
}
|
||||
|
||||
std::sort(runs.begin(), runs.end());
|
||||
|
||||
float cycles_final = runs[num_runs/2 + 1] / iter;
|
||||
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
|
||||
print_parameters();
|
||||
outfile << cycles_final << std::endl;
|
||||
}else{
|
||||
float cycles_final = int(10e12);
|
||||
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
|
||||
print_parameters();
|
||||
outfile << cycles_final << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
|
||||
def includes():
|
||||
out = " \
|
||||
#include <torch/all.h>\n \
|
||||
#include <torch/python.h>\n \
|
||||
#include <omp.h>\n \
|
||||
#include <cmath>\n \
|
||||
#include <immintrin.h>\n \
|
||||
\n \
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
|
||||
"
|
||||
return out
|
||||
|
||||
|
||||
def module(bits_list=[4, 2]):
|
||||
out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
|
||||
for bits in bits_list:
|
||||
out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
|
||||
|
||||
out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
|
||||
out += 'm.def("unquantize_sim", &unquantize_sim);\n'
|
||||
|
||||
# if oracle:
|
||||
# out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
|
||||
|
||||
|
||||
out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
|
||||
|
||||
out += '}\n'
|
||||
return out
|
||||
|
||||
def quant_scalar():
|
||||
out = " \
|
||||
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
|
||||
//find scales and zeros arrays \n \
|
||||
//quantize \n \
|
||||
int pack = 32/bits;\n \
|
||||
for (int j = 0; j < m; j++){\n \
|
||||
for (int i = 0; i < n; i+=pack){\n \
|
||||
uint32_t acc = 0;\n \
|
||||
for (int ii = i; ii < i+pack; ii++){\n \
|
||||
float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
|
||||
int temp = (int)ftemp;\n \
|
||||
acc = acc | (temp << (bits*(ii-i)));\n \
|
||||
}\n \
|
||||
BQ[(i/pack)*m+j] = acc;\n \
|
||||
//BQ[0] = acc;\n \
|
||||
}\n \
|
||||
}\n \
|
||||
}\n \
|
||||
\n \
|
||||
void quant_scalar_cpu(\n \
|
||||
torch::Tensor in, torch::Tensor out, \n \
|
||||
torch::Tensor scales, torch::Tensor zeros, int bits\n \
|
||||
) {\n \
|
||||
\n \
|
||||
int N = in.size(0);\n \
|
||||
int M = in.size(1);\n \
|
||||
\n \
|
||||
float* input = in.data_ptr<float>(); \n \
|
||||
float* s = scales.data_ptr<float>();\n \
|
||||
float* z = zeros.data_ptr<float>();\n \
|
||||
int* O = out.data_ptr<int>();\n \
|
||||
\n \
|
||||
quantize_scalar(input, O, s, z, N, M, bits);\n \
|
||||
\n \
|
||||
}\n"
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,37 +0,0 @@
bits,nu,mu,tu,unroll,p,gs,time
4,1,16,16,1,8,-1,1.3814e+06
4,1,16,16,2,8,-1,1.44087e+06
4,1,16,16,4,8,-1,1.56173e+06
4,1,16,16,8,8,-1,1.41389e+06
3,1,16,16,5,8,-1,2.14748e+09
2,1,16,16,1,8,-1,1.09513e+06
2,1,16,16,2,8,-1,1.11322e+06
2,1,16,16,4,8,-1,1.12031e+06
2,1,16,16,8,8,-1,1.19086e+06
4,1,16,16,1,8,64,1.69111e+06
4,1,16,16,2,8,64,1.60056e+06
4,1,16,16,4,8,64,1.41263e+06
4,1,16,16,8,8,64,1.74572e+06
3,1,16,16,5,8,64,1.48062e+06
2,1,16,16,1,8,64,1.51234e+06
2,1,16,16,2,8,64,1.68108e+06
2,1,16,16,4,8,64,1.7624e+06
2,1,16,16,8,8,64,1.69563e+06
4,1,16,32,1,8,-1,1.24798e+06
4,1,16,32,2,8,-1,1.58421e+06
4,1,16,32,4,8,-1,2.10718e+06
4,1,16,32,8,8,-1,1.54288e+06
3,1,16,32,5,8,-1,2.14748e+09
2,1,16,32,1,8,-1,1.55906e+06
2,1,16,32,2,8,-1,1.58576e+06
2,1,16,32,4,8,-1,1.57993e+06
2,1,16,32,8,8,-1,1.80443e+06
4,1,16,32,1,8,64,1.58354e+06
4,1,16,32,2,8,64,1.63248e+06
4,1,16,32,4,8,64,1.91902e+06
4,1,16,32,8,8,64,1.9243e+06
3,1,16,32,5,8,64,1.33812e+06
2,1,16,32,1,8,64,1.77522e+06
2,1,16,32,2,8,64,1.54702e+06
2,1,16,32,4,8,64,1.78772e+06
2,1,16,32,8,8,64,1.49612e+06
@ -1,13 +1,4 @@
## <center>News or Update</center>

- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so now running and training GPTQ models can be more available to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team officially released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, and provided [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization).
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel to get at least a 1.3x speed up for int4 quantized models when doing inference.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized models to train adapters, supporting LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support download/upload of quantized models from/to the 🤗 Hub.
- 2023-05-27 - (Update) - Support quantization and inference for `gpt_bigcode`, `codegen` and `RefineWeb/RefineWebModel` (falcon) model types.
- 2023-05-04 - (Update) - Support using a faster cuda kernel when `not desc_act or group_size == -1`.
- 2023-04-29 - (Update) - Support loading quantized models from an arbitrary quantize_config and model_basename.
- 2023-04-28 - (Update) - Support CPU offload and quantize/inference on multiple devices, support `gpt2` type models.
@ -78,9 +78,9 @@ Pretrained model's config and the quantize config will also be saved with file n
Instead of `.from_pretrained`, you should use `.from_quantized` to load a quantized model.
```python
device = "cuda:0"
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device)
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, use_triton=False)
```
This will first read and load `quantize_config.json` in `opt-125m-4bit-128g` directory, then based on the values of `bits` and `group_size` in it, load `gptq_model-4bit-128g.bin` model file into the first visible GPU.
This will first read and load `quantize_config.json` in `opt-125m-4bit-128g` directory, then based on the values of `bits` and `group_size` in it, load `gptq_model-4bit-128g.bin` model file into the first GPU.

Then you can initialize 🤗 Transformers' `TextGenerationPipeline` and do inference.
```python
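# The rest of this code block was truncated in the diff view; the lines below are
# a hedged sketch of the inference step the sentence above introduces. The prompt
# text and the `tokenizer` variable (assumed to be loaded earlier in the README)
# are placeholders, not content recovered from the diff.
from transformers import TextGenerationPipeline

pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipeline("auto-gptq is")[0]["generated_text"])
```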
@ -4,17 +4,13 @@ Welcome to the tutorial of AutoGPTQ, in this chapter, you will learn advanced mo
## Arguments Introduction

In the previous chapter, you learned how to load a model into CPU or a single GPU with the two basic APIs:
- `.from_pretrained`: by default, load the whole pretrained model into CPU.
- `.from_quantized`: by default, `auto_gptq` will automatically find the suitable way to load the quantized model.
  - if there is only a single GPU and the model can fit into it, the whole model is loaded into that GPU;
  - if there are multiple GPUs and the model can fit into them, the model is split evenly and loaded across those GPUs;
  - if the model can't fit into the GPU(s), CPU offloading is used.
- `.from_quantized`: by default, load the whole quantized model into CPU; one can set `device='cuda'` to load the model into a single GPU.

However, the default settings above may not meet many users' demands, for they want to have more control of model loading.
However, the default settings above may not meet many users' demands, for they want to try really large models but haven't enough CPU/GPU memory.

Luckily, in AutoGPTQ, we provide some advanced arguments that users can tweak to manually configure the model loading strategy (a short loading sketch follows below):
- `low_cpu_mem_usage`: `bool` type argument, defaults to False, can be used both in `.from_pretrained` and `.from_quantized`; one can enable it when there is a limitation of CPU memory (by default the model is initialized in CPU) or to load the model faster.
Luckily, in AutoGPTQ, we provide two advanced arguments that users can tweak based on the memory of the hardware:
- `max_memory`: an optional `List[Dict[Union[str, int], str]]` type argument, can be used both in `.from_pretrained` and `.from_quantized`.
- `device_map`: an optional `Union[str, Dict[str, Union[int, str]]]` type argument, currently only supported in `.from_quantized`.
- `device_map`: an optional `str` type argument, currently only supported in `.from_quantized`.

Before `auto-gptq`'s existence, many users had already used other popular tools such as [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa) to quantize their models and saved them under different names, without the `quantize_config.json` file introduced in the previous chapter.
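As a minimal sketch of how these arguments combine (the model directory and memory figures are placeholders; only the argument names come from the text above):

```python
from auto_gptq import AutoGPTQForCausalLM

quantized_model_dir = "opt-125m-4bit-128g"       # placeholder path
max_memory = {0: "20GIB", "cpu": "20GIB"}        # per-device budget, as described above

model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir,
    low_cpu_mem_usage=True,   # documented above for both loading APIs
    max_memory=max_memory,    # cap what GPU 0 and the CPU may hold
)
```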
@ -54,18 +50,16 @@ max_memory = {0: "20GIB", "cpu": "20GIB"}
|
|||
In this case, you can also load a model smaller than 40GB, but the remaining 20GB will be kept in CPU memory and only moved onto the GPU when needed.
|
||||
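For example, a minimal sketch (the model directory and budgets here are placeholders) of passing such a budget to `.from_quantized`, so that weights exceeding the GPU budget stay in CPU memory:

```python
from auto_gptq import AutoGPTQForCausalLM

# 20GB on GPU 0 plus 20GB of CPU memory; the exact budgets are illustrative.
max_memory = {0: "20GIB", "cpu": "20GIB"}

# "opt-125m-4bit-128g" is just a placeholder quantized model directory.
model = AutoGPTQForCausalLM.from_quantized(
    "opt-125m-4bit-128g",
    max_memory=max_memory,
)
```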
|
||||
### device_map
|
||||
So far, only `.from_quantized` supports this argument.
|
||||
So far, only `.from_quantized` supports this argument. You can specify it to use pre-set model loading strategies. Because, under the hood, the model's modules are mapped to different devices based on the given `max_memory`, it's more convenient to use `device_map` directly if you don't want to spend time calculating how much memory each device should use to load the model.
|
||||
|
||||
You can provide a string to this argument to use pre-set model loading strategies. The currently valid values are `["auto", "balanced", "balanced_low_0", "sequential"]`.
|
||||
|
||||
In the simplest case, you can set `device_map='auto'` and let 🤗 Accelerate handle the device map computation. For more details on this argument, refer to [this document](https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
In the simplest case, you can set `device_map='auto'` and let 🤗 Accelerate handle the device map computation. For more pre-set strategies, refer to [this document](https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
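As a minimal sketch (the model directory is again a placeholder), delegating placement entirely to 🤗 Accelerate:

```python
from auto_gptq import AutoGPTQForCausalLM

# Let Accelerate split the quantized model across all visible GPUs,
# spilling over to CPU memory only if necessary.
model = AutoGPTQForCausalLM.from_quantized(
    "opt-125m-4bit-128g",  # placeholder quantized model directory
    device_map="auto",
)
```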
|
||||
## Best Practice
|
||||
|
||||
### At Quantization
|
||||
It's always recommended to first consider loading the whole model into the GPU(s), as this saves the time spent transferring module weights between CPU and GPU.
|
||||
|
||||
However, not everyone has plenty of GPU memory. Roughly speaking, always specify the maximum CPU memory to be used for loading the model; then, for each GPU, reserve enough memory to hold 1\~2 model layers (2\~3 for the first GPU in case CPU offload is used) for the examples' tensors and the calculations during quantization, and use all the rest for model weights. With this, all you need is some simple math based on the number of GPUs you have, the size of the model weight file(s) and the number of model layers.
|
||||
However, not everyone has plenty of GPU memory. Roughly speaking, always specify the maximum CPU memory to be used for loading the model; then, for each GPU, reserve enough memory to hold 1~2 model layers (2~3 for the first GPU in case CPU offload is used) for the examples' tensors and the calculations during quantization, and use all the rest for model weights. With this, all you need is some simple math based on the number of GPUs you have, the size of the model weight file(s) and the number of model layers.
|
||||
|
||||
### At Inference
|
||||
For inference, follow this principle: use a single GPU if you can, otherwise multiple GPUs; CPU offload is the last resort.
|
||||
|
|
|
@ -11,11 +11,9 @@ To Execute `basic_usage.py`, using command like this:
|
|||
python basic_usage.py
|
||||
```
|
||||
|
||||
This script also showcases how to download/upload a quantized model from/to the 🤗 Hub; to enable those features, uncomment the commented-out code.
|
||||
|
||||
To execute `basic_usage_wikitext2.py`, use a command like this:
|
||||
To execute `basic_usage_with_wikitext2.py`, use a command like this:
|
||||
```shell
|
||||
python basic_usage_wikitext2.py
|
||||
python basic_usage_with_wikitext2.py
|
||||
```
|
||||
> Note: There is about a 0.6 ppl degradation on the opt-125m model when using AutoGPTQ, compared to GPTQ-for-LLaMa.
|
||||
|
||||
|
@ -62,52 +60,21 @@ CUDA_VISIBLE_DEVICES=0 python run_text_summarization_task.py --base_model_dir PA
|
|||
|
||||
Use the `--help` flag to see detailed descriptions of additional command arguments.
|
||||
|
||||
## Benchmark
|
||||
> Commands in this chapter should be run under the `benchmark` folder.
|
||||
## Push To Hub
|
||||
> Commands in this chapter should be run under the `push_to_hub` folder.
|
||||
|
||||
### Generation Speed
|
||||
The `generation_speed.py` script gives an example of how to benchmark the generation speed of pretrained and quantized models that `auto_gptq` supports; it reports generation speed in tokens/s.
|
||||
You can upload and share your quantized model on the Hugging Face Hub by using the `push_to_hub` function.
|
||||
|
||||
To execute this script, use a command like this:
|
||||
`push_quantized_model_to_hf_hub.py` provides a simple example of uploading a quantized model, tokenizer and configs at once.
|
||||
|
||||
First, you need to log in; run the following command in the virtual environment where Hugging Face Transformers is installed:
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0 python generation_speed.py --model_name_or_path PATH/TO/MODEL/DIR
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
Then run the script like this:
|
||||
```shell
|
||||
python push_quantized_model_to_hf_hub.py --quantized_model_dir PATH/TO/QUANTIZED/MODEL/DIR --tokenizer_dir PATH/TO/TOKENIZER/DIR --repo_id REPO/ID
|
||||
```
|
||||
|
||||
Use the `--help` flag to see detailed descriptions of additional command arguments.
|
||||
|
||||
## PEFT
|
||||
> Commands in this chapter should be run under the `peft` folder.
|
||||
|
||||
### Lora
|
||||
The `peft_lora_clm_instruction_tuning.py` script gives an example of instruction-tuning a GPTQ-quantized model's LoRA adapter using tools in `auto_gptq.utils.peft_utils` and `🤗 peft` on the Alpaca dataset.
|
||||
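The key calls inside the script look roughly like this (a sketch with the script's default hyperparameters; the model path is a placeholder, and the full training loop is omitted):

```python
from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
from auto_gptq.utils.peft_utils import GPTQLoraConfig
from peft import TaskType

# Describe the LoRA adapter to inject into the quantized model.
peft_config = GPTQLoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)

# Load the quantized model in trainable mode, then wrap it with the adapter.
model = AutoGPTQForCausalLM.from_quantized(
    "PATH/TO/MODEL/DIR",  # placeholder
    use_triton=True,
    trainable=True,
)
model = get_gptq_peft_model(model, peft_config=peft_config, auto_find_all_linears=True, train_mode=True)
```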
|
||||
To execute this script, use a command like this:
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0 python peft_lora_clm_instruction_tuning.py --model_name_or_path PATH/TO/MODEL/DIR
|
||||
```
|
||||
|
||||
Use the `--help` flag to see detailed descriptions of additional command arguments.
|
||||
|
||||
### AdaLora
|
||||
The `peft_adalora_clm_instruction_tuning.py` script gives an example of instruction-tuning a GPTQ-quantized model's AdaLoRA adapter using tools in `auto_gptq.utils.peft_utils` and `🤗 peft` on the Alpaca dataset.
|
||||
|
||||
To execute this script, use a command like this:
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0 python peft_adalora_clm_instruction_tuning.py --model_name_or_path PATH/TO/MODEL/DIR
|
||||
```
|
||||
|
||||
Use the `--help` flag to see detailed descriptions of additional command arguments.
|
||||
|
||||
|
||||
### AdaptionPrompt
|
||||
The `peft_adaption_prompt_clm_instruction_tuning.py` script gives an example of instruction-tuning a GPTQ-quantized model's adaption_prompt adapter (llama-adapter) using tools in `auto_gptq.utils.peft_utils` and `🤗 peft` on the Alpaca dataset.
|
||||
|
||||
To execute this script, use a command like this:
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0 python peft_adaption_prompt_clm_instruction_tuning.py --model_name_or_path PATH/TO/MODEL/DIR
|
||||
```
|
||||
|
||||
Use the `--help` flag to see detailed descriptions of additional command arguments.
|
||||
|
||||
If you want to try models other than llama, you can install peft from source using [this branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt), see [here](https://github.com/PanQiWei/peft/blob/a5f8f74f07591efe5eb3d08cb1b31b981e84a069/src/peft/tuners/adaption_prompt.py#L235)
|
||||
to check which other models are also supported; with this branch installed, you can also use the `ADAPTION_PROMPT_V2` peft type (llama-adapter-v2) by simply replacing `AdaptionPromptConfig` with `AdaptionPromptV2Config` in the script.
|
|
@ -1,318 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
from argparse import ArgumentParser
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from datasets import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
random.seed(0)
|
||||
|
||||
|
||||
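# Masks the EOS token's logits until at least `min_new_tokens` new tokens have
# been generated, forcing generation to run for a minimum number of steps.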
class CustomizedMinNewTokensLogitsProcessor(LogitsProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
min_new_tokens: int = None,
|
||||
eos_token_id: int = None,
|
||||
):
|
||||
self.eos_token_id = eos_token_id
|
||||
self.min_new_tokens = min_new_tokens or 0
|
||||
self.current_step = 0
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
self.current_step += 1
|
||||
|
||||
if self._skip_process():
|
||||
return scores
|
||||
|
||||
if any(each is not None for each in [self.eos_token_id]):
|
||||
banned_mask = torch.zeros_like(scores).to(scores.device)
|
||||
if self.eos_token_id and self.current_step <= self.min_new_tokens:
|
||||
banned_mask = self._fill_banned_mask(input_ids, banned_mask, {1: [[self.eos_token_id]]})
|
||||
scores = scores.masked_fill(banned_mask.bool(), -float("inf"))
|
||||
|
||||
return scores
|
||||
|
||||
def _skip_process(self):
|
||||
if self.current_step > self.min_new_tokens:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _fill_banned_mask(
|
||||
input_ids: torch.LongTensor,
|
||||
banned_mask: torch.Tensor,
|
||||
len2words_ids: Dict[int, List[List[int]]]
|
||||
):
|
||||
for token_len, token_ids in len2words_ids.items():
|
||||
if token_len == 1:
|
||||
banned_mask[..., list(chain(*token_ids))] = 1
|
||||
elif input_ids.shape[-1] < token_len - 1:
|
||||
continue
|
||||
else:
|
||||
token_ids = torch.LongTensor(token_ids).to(input_ids.device)
|
||||
hit_masks = torch.all(
|
||||
token_ids[..., :-1].unsqueeze(0).repeat(input_ids.shape[0], 1, 1)
|
||||
== input_ids[..., -(token_ids.shape[-1] - 1):].unsqueeze(1),
|
||||
dim=-1
|
||||
)
|
||||
for idx in range(hit_masks.shape[0]):
|
||||
selected_token_ids = torch.masked_select(token_ids[..., -1], hit_masks[idx])
|
||||
if len(selected_token_ids):
|
||||
banned_mask[idx, selected_token_ids] = 1
|
||||
return banned_mask
|
||||
|
||||
|
||||
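# Load alpaca-style instruction/input/output records, sample up to `n_samples`
# of them, and keep only prompts that leave room for `max_new_tokens` within
# the tokenizer's model_max_length.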
def load_data(data_path, tokenizer, n_samples, max_new_tokens):
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw_data = json.load(f)
|
||||
|
||||
raw_data = random.sample(raw_data, k=min(n_samples, len(raw_data)))
|
||||
|
||||
def dummy_gen():
|
||||
return raw_data
|
||||
|
||||
def tokenize(examples):
|
||||
instructions = examples["instruction"]
|
||||
inputs = examples["input"]
|
||||
outputs = examples["output"]
|
||||
|
||||
prompts = []
|
||||
texts = []
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
for istr, inp, opt in zip(instructions, inputs, outputs):
|
||||
if inp:
|
||||
prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
|
||||
text = prompt + opt
|
||||
else:
|
||||
prompt = f"Instruction:\n{istr}\nOutput:\n"
|
||||
text = prompt + opt
|
||||
if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length - max_new_tokens:
|
||||
continue
|
||||
|
||||
tokenized_data = tokenizer(text)
|
||||
|
||||
input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
|
||||
attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
|
||||
prompts.append(prompt)
|
||||
texts.append(text)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"prompt": prompts
|
||||
}
|
||||
|
||||
dataset = Dataset.from_generator(dummy_gen)
|
||||
|
||||
dataset = dataset.map(
|
||||
tokenize,
|
||||
batched=True,
|
||||
batch_size=len(dataset),
|
||||
num_proc=1,
|
||||
keep_in_memory=True,
|
||||
load_from_cache_file=False,
|
||||
remove_columns=["instruction", "input"]
|
||||
)
|
||||
|
||||
dataset = dataset.to_list()
|
||||
|
||||
for sample in dataset:
|
||||
sample["input_ids"] = torch.LongTensor(sample["input_ids"])
|
||||
sample["attention_mask"] = torch.LongTensor(sample["attention_mask"])
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
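# Build the tokenizer and either a full-precision model (from_pretrained) or a
# GPTQ-quantized one (from_quantized), forwarding the chosen kernel options.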
def load_model_tokenizer(
|
||||
model_name_or_path: str,
|
||||
tokenizer_name_or_path: Optional[str] = None,
|
||||
from_pretrained: bool = False,
|
||||
max_memory: Optional[dict] = None,
|
||||
model_basename: Optional[str] = None,
|
||||
quantize_config: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_triton: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
use_fast_tokenizer: bool = False,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True,
|
||||
disable_exllama: bool = False
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
|
||||
use_fast=use_fast_tokenizer,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
if not tokenizer.pad_token_id:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
if from_pretrained:
|
||||
model = AutoGPTQForCausalLM.from_pretrained(
|
||||
pretrained_model_name_or_path=model_name_or_path,
|
||||
quantize_config=BaseQuantizeConfig(),
|
||||
max_memory=max_memory,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
else:
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
model_name_or_path,
|
||||
max_memory=max_memory,
|
||||
low_cpu_mem_usage=True,
|
||||
use_triton=use_triton,
|
||||
inject_fused_attention=inject_fused_attention,
|
||||
inject_fused_mlp=inject_fused_mlp,
|
||||
use_cuda_fp16=True,
|
||||
quantize_config=quantize_config,
|
||||
model_basename=model_basename,
|
||||
use_safetensors=use_safetensors,
|
||||
trust_remote_code=trust_remote_code,
|
||||
warmup_triton=False,
|
||||
disable_exllama=disable_exllama
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
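# Generate a completion for every example, timing each call and counting the
# non-padding tokens produced, then log the aggregate tokens/s throughput.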
def benchmark_generation_speed(model, tokenizer, examples, generation_config):
|
||||
generation_time_list = []
|
||||
num_generated_tokens_list = []
|
||||
progress_bar = tqdm(examples)
|
||||
for example in progress_bar:
|
||||
input_ids = example["input_ids"].to(model.device)
|
||||
|
||||
start = time.time()
|
||||
outputs_ids = model.generate(
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
generation_config=generation_config,
|
||||
logits_processor=[
|
||||
CustomizedMinNewTokensLogitsProcessor(generation_config.max_new_tokens, tokenizer.eos_token_id)
|
||||
]
|
||||
)
|
||||
end = time.time()
|
||||
|
||||
generation_time_list.append(end - start)
|
||||
num_generated_tokens = 0
|
||||
for output_ids in outputs_ids:
|
||||
num_generated_tokens += len(
|
||||
[
|
||||
token_id for token_id in output_ids[len(input_ids):] if token_id != tokenizer.pad_token_id
|
||||
]
|
||||
)
|
||||
num_generated_tokens_list.append(num_generated_tokens)
|
||||
|
||||
progress_bar.set_postfix(
|
||||
num_tokens=num_generated_tokens_list[-1],
|
||||
time=generation_time_list[-1],
|
||||
speed=f"{num_generated_tokens_list[-1] / generation_time_list[-1]:.4f}tokens/s"
|
||||
)
|
||||
|
||||
total_tokens = sum(num_generated_tokens_list)
|
||||
total_seconds = sum(generation_time_list)
|
||||
logger.info(
|
||||
f"generated {total_tokens} tokens using {total_seconds} seconds, "
|
||||
f"generation speed: {total_tokens / total_seconds}tokens/s"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_name_or_path", type=str)
|
||||
parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
|
||||
parser.add_argument("--from_pretrained", action="store_true")
|
||||
parser.add_argument("--model_basename", type=str, default=None)
|
||||
parser.add_argument("--quantize_config_save_dir", type=str, default=None)
|
||||
parser.add_argument("--trust_remote_code", action="store_true")
|
||||
parser.add_argument("--use_triton", action="store_true")
|
||||
parser.add_argument("--use_safetensors", action="store_true")
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true")
|
||||
parser.add_argument("--disable_exllama", action="store_true")
|
||||
parser.add_argument("--no_inject_fused_attention", action="store_true")
|
||||
parser.add_argument("--no_inject_fused_mlp", action="store_true")
|
||||
parser.add_argument("--num_samples", type=int, default=10)
|
||||
parser.add_argument("--per_gpu_max_memory", type=int, default=None)
|
||||
parser.add_argument("--cpu_max_memory", type=int, default=None)
|
||||
parser.add_argument("--max_new_tokens", type=int, default=512)
|
||||
parser.add_argument("--do_sample", action="store_true")
|
||||
parser.add_argument("--num_beams", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
max_memory = dict()
|
||||
if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
|
||||
if torch.cuda.is_available():
|
||||
max_memory.update(
|
||||
{i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
|
||||
)
|
||||
if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
|
||||
max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
|
||||
if not max_memory:
|
||||
max_memory = None
|
||||
|
||||
logger.info(f"max_memory: {max_memory}")
|
||||
|
||||
quantize_config = None
|
||||
if args.quantize_config_save_dir:
|
||||
quantize_config = BaseQuantizeConfig.from_pretrained(args.quantize_config_save_dir)
|
||||
|
||||
logger.info("loading model and tokenizer")
|
||||
start = time.time()
|
||||
model, tokenizer = load_model_tokenizer(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
tokenizer_name_or_path=args.tokenizer_name_or_path,
|
||||
from_pretrained=args.from_pretrained,
|
||||
max_memory=max_memory,
|
||||
model_basename=args.model_basename,
|
||||
quantize_config=quantize_config,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
use_triton=args.use_triton,
|
||||
use_safetensors=args.use_safetensors,
|
||||
use_fast_tokenizer=args.use_fast_tokenizer,
|
||||
inject_fused_attention=not args.no_inject_fused_attention,
|
||||
inject_fused_mlp=not args.no_inject_fused_mlp,
|
||||
disable_exllama=args.disable_exllama
|
||||
)
|
||||
end = time.time()
|
||||
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")
|
||||
logger.info(f"model quantized: {model.quantized}")
|
||||
logger.info(f"quantize config: {model.quantize_config.to_dict()}")
|
||||
logger.info(f"model device map: {model.hf_device_map}")
|
||||
|
||||
if args.use_triton:
|
||||
logger.info("warmup triton, this may take a while.")
|
||||
model.warmup_triton()
|
||||
|
||||
logger.info("loading data")
|
||||
examples = load_data(
|
||||
"../quantization/dataset/alpaca_data_cleaned.json", tokenizer, args.num_samples, args.max_new_tokens
|
||||
)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
num_beams=args.num_beams,
|
||||
num_return_sequences=args.num_beams,
|
||||
do_sample=args.do_sample,
|
||||
min_new_tokens=args.max_new_tokens,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
logger.info(f"generation config: {generation_config.to_dict()}")
|
||||
|
||||
logger.info(f"benchmark generation speed")
|
||||
benchmark_generation_speed(model, tokenizer, examples, generation_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
main()
|
|
@ -1,88 +0,0 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from auto_gptq.utils import Perplexity
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Example usage.
|
||||
|
||||
Default usage with GPT2 model:
|
||||
python examples/benchmark/perplexity.py
|
||||
|
||||
Specify GPTQ quantized model:
|
||||
python examples/benchmark/perplexity.py \
|
||||
--model_name TheBloke/open-llama-7b-open-instruct-GPTQ \
|
||||
--model_basename gptq_model-4bit-128g \
|
||||
--is_quantized
|
||||
|
||||
Change your dataset:
|
||||
python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare
|
||||
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
|
||||
parser.add_argument("--model_name", type=str, default='gpt2', help="Model name.")
|
||||
parser.add_argument("--model_basename", type=str, default=None, help="Model file's basename.")
|
||||
parser.add_argument("--n_ctx", type=int, default=512, help="Context size.")
|
||||
parser.add_argument("--n_batch", type=int, default=512, help="Batch size.")
|
||||
parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
|
||||
parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
|
||||
parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
|
||||
parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
|
||||
parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
|
||||
parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.")
|
||||
parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
|
||||
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
|
||||
parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
|
||||
if not tokenizer.pad_token_id:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
max_memory = dict()
|
||||
if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
|
||||
if torch.cuda.is_available():
|
||||
max_memory.update(
|
||||
{i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
|
||||
)
|
||||
if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
|
||||
max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
|
||||
if not max_memory:
|
||||
max_memory = None
|
||||
|
||||
if args.is_quantized:
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
args.model_name,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
model_basename=args.model_basename,
|
||||
use_safetensors=args.use_safetensors,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
inject_fused_mlp=False,
|
||||
inject_fused_attention=False,
|
||||
disable_exllama=args.disable_exllama
|
||||
)
|
||||
else:
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
|
||||
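# Compute perplexity over the selected dataset split with context size n_ctx,
# processing n_batch tokens at a time.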
ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column)
|
||||
ppl.calculate_perplexity(args.n_ctx, args.n_batch)
|
|
@ -1,169 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
|
||||
from auto_gptq.utils.data_utils import make_data_block, collate_data
|
||||
from auto_gptq.utils.peft_utils import GPTQAdaLoraConfig
|
||||
from peft import TaskType
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_name_or_path", type=str)
|
||||
parser.add_argument("--lr", type=float, default=3e-3)
|
||||
parser.add_argument("--num_epochs", type=int, default=1)
|
||||
parser.add_argument("--sample_max_length", type=int, default=1024, help="max length of sample")
|
||||
parser.add_argument("--block_max_length", type=int, default=1024, help="max length of data block(bunch of samples)")
|
||||
parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
model_name_or_path = args.model_name_or_path
|
||||
tokenizer_name_or_path = args.tokenizer_name_or_path or model_name_or_path
|
||||
|
||||
lr = args.lr
|
||||
num_epochs = args.num_epochs
|
||||
|
||||
# creating model
|
||||
peft_config = GPTQAdaLoraConfig(
|
||||
init_r=20,
|
||||
target_r=16,
|
||||
beta1=0.85,
|
||||
beta2=0.85,
|
||||
tinit=200,
|
||||
tfinal=1000,
|
||||
deltaT=10,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=args.use_fast_tokenizer)
|
||||
if not tokenizer.pad_token_id:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
model_name_or_path,
|
||||
use_triton=True,
|
||||
warmup_triton=False,
|
||||
trainable=True,
|
||||
inject_fused_attention=True,
|
||||
inject_fused_mlp=False
|
||||
)
|
||||
model.warmup_triton()
|
||||
device = model.device
|
||||
model = get_gptq_peft_model(model, peft_config=peft_config, auto_find_all_linears=True, train_mode=True)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# loading dataset
|
||||
WITH_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n"
|
||||
WITHOUT_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Output:\n"
|
||||
|
||||
|
||||
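# Convert raw alpaca records into prompt/output pairs using the templates above.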
def ds_refactor_fn(samples):
|
||||
instruction_data = samples["instruction"]
|
||||
input_data = samples["input"]
|
||||
output_data = samples["output"]
|
||||
|
||||
new_samples = {"prompt": [], "output": []}
|
||||
for instruction_txt, input_txt, output_txt in zip(instruction_data, input_data, output_data):
|
||||
if input_txt:
|
||||
prompt = WITH_INPUT_TEMPLATE.format(instruction=instruction_txt, input=input_txt)
|
||||
else:
|
||||
prompt = WITHOUT_INPUT_TEMPLATE.format(instruction=instruction_txt)
|
||||
new_samples["prompt"].append(prompt)
|
||||
new_samples["output"].append(output_txt)
|
||||
|
||||
return new_samples
|
||||
|
||||
|
||||
ds = Dataset.from_generator(
|
||||
lambda: json.load(open("../quantization/dataset/alpaca_data_cleaned.json", "r", encoding="utf-8"))
|
||||
)
|
||||
ds = ds.map(
|
||||
make_data_block,
|
||||
batched=True,
|
||||
batch_size=len(ds),
|
||||
num_proc=1,
|
||||
remove_columns=ds.column_names,
|
||||
keep_in_memory=True,
|
||||
load_from_cache_file=False,
|
||||
fn_kwargs={
|
||||
"prompt_col_name": "prompt",
|
||||
"label_col_name": "output",
|
||||
"tokenizer": tokenizer,
|
||||
"preprocess_fn": ds_refactor_fn,
|
||||
"sample_max_len": args.sample_max_length,
|
||||
"block_max_len": args.block_max_length,
|
||||
"add_eos_token": True,
|
||||
"truncate_prompt": False,
|
||||
"merge_prompt_label": True
|
||||
}
|
||||
)
|
||||
ds = ds.train_test_split(test_size=len(ds) // 10)
|
||||
train_ds, eval_ds = ds["train"], ds["test"]
|
||||
collate_fn = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
|
||||
train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=partial(collate_fn))
|
||||
eval_dataloader = DataLoader(eval_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)
|
||||
|
||||
# optimizer and lr scheduler
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=(len(train_dataloader) * num_epochs),
|
||||
)
|
||||
model.base_model.peft_config["default"].total_step = len(train_dataloader) * num_epochs
|
||||
|
||||
# training and evaluation
|
||||
with torch.cuda.amp.autocast():
|
||||
global_step = 0
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
progress_bar = tqdm(train_dataloader)
|
||||
for step, batch in enumerate(progress_bar):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
total_loss += loss.detach().float()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
# Update the importance of low-rank matrices
|
||||
# and allocate the budget accordingly.
|
||||
model.base_model.update_and_allocate(global_step)
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
progress_bar.set_postfix(loss=loss.item())
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
eval_preds = []
|
||||
for step, batch in enumerate(tqdm(eval_dataloader)):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
eval_loss += loss.detach().float()
|
||||
eval_preds.extend(
|
||||
tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
|
||||
)
|
||||
|
||||
eval_epoch_loss = eval_loss / len(eval_dataloader)
|
||||
eval_ppl = torch.exp(eval_epoch_loss)
|
||||
train_epoch_loss = total_loss / len(train_dataloader)
|
||||
train_ppl = torch.exp(train_epoch_loss)
|
||||
print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")
|
||||
|
||||
model.save_pretrained(os.path.join(model_name_or_path, f"gptq_{peft_config.peft_type.value}_adapter"))
|
|
@ -1,158 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
|
||||
from auto_gptq.utils.data_utils import make_data_block, collate_data
|
||||
from peft import TaskType, AdaptionPromptConfig
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_name_or_path", type=str)
|
||||
parser.add_argument("--adapter_len", type=int, default=10)
|
||||
parser.add_argument("--adapter_layers", type=int, default=30)
|
||||
parser.add_argument("--lr", type=float, default=3e-3)
|
||||
parser.add_argument("--num_epochs", type=int, default=1)
|
||||
parser.add_argument("--sample_max_length", type=int, default=1024, help="max length of sample")
|
||||
parser.add_argument("--block_max_length", type=int, default=1024, help="max length of data block(bunch of samples)")
|
||||
parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
model_name_or_path = args.model_name_or_path
|
||||
tokenizer_name_or_path = args.tokenizer_name_or_path or model_name_or_path
|
||||
|
||||
lr = args.lr
|
||||
num_epochs = args.num_epochs
|
||||
|
||||
# creating model
|
||||
peft_config = AdaptionPromptConfig(
|
||||
adapter_len=args.adapter_len,
|
||||
adapter_layers=args.adapter_layers,
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=args.use_fast_tokenizer)
|
||||
if not tokenizer.pad_token_id:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
model_name_or_path,
|
||||
use_triton=True,
|
||||
warmup_triton=False,
|
||||
trainable=True,
|
||||
inject_fused_attention=False,
|
||||
inject_fused_mlp=False
|
||||
)
|
||||
model.warmup_triton()
|
||||
device = model.device
|
||||
model = get_gptq_peft_model(model, peft_config=peft_config, auto_find_all_linears=True, train_mode=True)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# loading dataset
|
||||
WITH_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n"
|
||||
WITHOUT_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Output:\n"
|
||||
|
||||
|
||||
def ds_refactor_fn(samples):
|
||||
instruction_data = samples["instruction"]
|
||||
input_data = samples["input"]
|
||||
output_data = samples["output"]
|
||||
|
||||
new_samples = {"prompt": [], "output": []}
|
||||
for instruction_txt, input_txt, output_txt in zip(instruction_data, input_data, output_data):
|
||||
if input_txt:
|
||||
prompt = WITH_INPUT_TEMPLATE.format(instruction=instruction_txt, input=input_txt)
|
||||
else:
|
||||
prompt = WITHOUT_INPUT_TEMPLATE.format(instruction=instruction_txt)
|
||||
new_samples["prompt"].append(prompt)
|
||||
new_samples["output"].append(output_txt)
|
||||
|
||||
return new_samples
|
||||
|
||||
|
||||
ds = Dataset.from_generator(
|
||||
lambda: json.load(open("../quantization/dataset/alpaca_data_cleaned.json", "r", encoding="utf-8"))
|
||||
)
|
||||
ds = ds.map(
|
||||
make_data_block,
|
||||
batched=True,
|
||||
batch_size=len(ds),
|
||||
num_proc=1,
|
||||
remove_columns=ds.column_names,
|
||||
keep_in_memory=True,
|
||||
load_from_cache_file=False,
|
||||
fn_kwargs={
|
||||
"prompt_col_name": "prompt",
|
||||
"label_col_name": "output",
|
||||
"tokenizer": tokenizer,
|
||||
"preprocess_fn": ds_refactor_fn,
|
||||
"sample_max_len": args.sample_max_length,
|
||||
"block_max_len": args.block_max_length,
|
||||
"add_eos_token": True,
|
||||
"truncate_prompt": False,
|
||||
"merge_prompt_label": True
|
||||
}
|
||||
)
|
||||
ds = ds.train_test_split(test_size=len(ds) // 10)
|
||||
train_ds, eval_ds = ds["train"], ds["test"]
|
||||
collate_fn = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
|
||||
train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=partial(collate_fn))
|
||||
eval_dataloader = DataLoader(eval_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)
|
||||
|
||||
# optimizer and lr scheduler
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=(len(train_dataloader) * num_epochs),
|
||||
)
|
||||
|
||||
# training and evaluation
|
||||
with torch.cuda.amp.autocast():
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
progress_bar = tqdm(train_dataloader)
|
||||
for step, batch in enumerate(progress_bar):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
total_loss += loss.detach().float()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
progress_bar.set_postfix(loss=loss.item())
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
eval_preds = []
|
||||
for step, batch in enumerate(tqdm(eval_dataloader)):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
eval_loss += loss.detach().float()
|
||||
eval_preds.extend(
|
||||
tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
|
||||
)
|
||||
|
||||
eval_epoch_loss = eval_loss / len(eval_dataloader)
|
||||
eval_ppl = torch.exp(eval_epoch_loss)
|
||||
train_epoch_loss = total_loss / len(train_dataloader)
|
||||
train_ppl = torch.exp(train_epoch_loss)
|
||||
print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")
|
||||
|
||||
model.save_pretrained(os.path.join(model_name_or_path, f"gptq_{peft_config.peft_type.value}_adapter"))
|
|
@ -1,158 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
|
||||
from auto_gptq.utils.data_utils import make_data_block, collate_data
|
||||
from auto_gptq.utils.peft_utils import GPTQLoraConfig
|
||||
from peft import TaskType
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_name_or_path", type=str)
|
||||
parser.add_argument("--lr", type=float, default=3e-5)
|
||||
parser.add_argument("--num_epochs", type=int, default=1)
|
||||
parser.add_argument("--sample_max_length", type=int, default=1024, help="max length of sample")
|
||||
parser.add_argument("--block_max_length", type=int, default=1024, help="max length of data block(bunch of samples)")
|
||||
parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
model_name_or_path = args.model_name_or_path
|
||||
tokenizer_name_or_path = args.tokenizer_name_or_path or model_name_or_path
|
||||
|
||||
lr = args.lr
|
||||
num_epochs = args.num_epochs
|
||||
|
||||
# creating model
|
||||
peft_config = GPTQLoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=args.use_fast_tokenizer)
|
||||
if not tokenizer.pad_token_id:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
model_name_or_path,
|
||||
use_triton=True,
|
||||
warmup_triton=False,
|
||||
trainable=True,
|
||||
inject_fused_attention=True,
|
||||
inject_fused_mlp=False
|
||||
)
|
||||
model.warmup_triton()
|
||||
device = model.device
|
||||
model = get_gptq_peft_model(model, peft_config=peft_config, auto_find_all_linears=True, train_mode=True)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# loading dataset
|
||||
WITH_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n"
|
||||
WITHOUT_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Output:\n"
|
||||
|
||||
|
||||
def ds_refactor_fn(samples):
|
||||
instruction_data = samples["instruction"]
|
||||
input_data = samples["input"]
|
||||
output_data = samples["output"]
|
||||
|
||||
new_samples = {"prompt": [], "output": []}
|
||||
for instruction_txt, input_txt, output_txt in zip(instruction_data, input_data, output_data):
|
||||
if input_txt:
|
||||
prompt = WITH_INPUT_TEMPLATE.format(instruction=instruction_txt, input=input_txt)
|
||||
else:
|
||||
prompt = WITHOUT_INPUT_TEMPLATE.format(instruction=instruction_txt)
|
||||
new_samples["prompt"].append(prompt)
|
||||
new_samples["output"].append(output_txt)
|
||||
|
||||
return new_samples
|
||||
|
||||
|
||||
ds = Dataset.from_generator(
|
||||
lambda: json.load(open("../quantization/dataset/alpaca_data_cleaned.json", "r", encoding="utf-8"))
|
||||
)
|
||||
ds = ds.map(
|
||||
make_data_block,
|
||||
batched=True,
|
||||
batch_size=len(ds),
|
||||
num_proc=1,
|
||||
remove_columns=ds.column_names,
|
||||
keep_in_memory=True,
|
||||
load_from_cache_file=False,
|
||||
fn_kwargs={
|
||||
"prompt_col_name": "prompt",
|
||||
"label_col_name": "output",
|
||||
"tokenizer": tokenizer,
|
||||
"preprocess_fn": ds_refactor_fn,
|
||||
"sample_max_len": args.sample_max_length,
|
||||
"block_max_len": args.block_max_length,
|
||||
"add_eos_token": True,
|
||||
"truncate_prompt": False,
|
||||
"merge_prompt_label": True
|
||||
}
|
||||
)
|
||||
ds = ds.train_test_split(test_size=len(ds) // 10)
|
||||
train_ds, eval_ds = ds["train"], ds["test"]
|
||||
collate_fn = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
|
||||
train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=partial(collate_fn))
|
||||
eval_dataloader = DataLoader(eval_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)
|
||||
|
||||
# optimizer and lr scheduler
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=(len(train_dataloader) * num_epochs),
|
||||
)
|
||||
|
||||
# training and evaluation
|
||||
with torch.cuda.amp.autocast():
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
progress_bar = tqdm(train_dataloader)
|
||||
for step, batch in enumerate(progress_bar):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
total_loss += loss.detach().float()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
progress_bar.set_postfix(loss=loss.item())
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
eval_preds = []
|
||||
for step, batch in enumerate(tqdm(eval_dataloader)):
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
eval_loss += loss.detach().float()
|
||||
eval_preds.extend(
|
||||
tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
|
||||
)
|
||||
|
||||
eval_epoch_loss = eval_loss / len(eval_dataloader)
|
||||
eval_ppl = torch.exp(eval_epoch_loss)
|
||||
train_epoch_loss = total_loss / len(train_dataloader)
|
||||
train_ppl = torch.exp(train_epoch_loss)
|
||||
print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")
|
||||
|
||||
model.save_pretrained(os.path.join(model_name_or_path, f"gptq_{peft_config.peft_type.value}_adapter"))
|
55
examples/push_to_hub/push_quantized_model_to_hf_hub.py
Normal file
55
examples/push_to_hub/push_quantized_model_to_hf_hub.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
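# Load an already-quantized model from disk and push the model weights, config
# and quantize_config (plus, optionally, the tokenizer) to the Hub repository.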
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--quantized_model_dir", type=str, help="Directory that saves quantized model.")
|
||||
parser.add_argument("--repo_id", type=str, help="The name of the repository you want to push to.")
|
||||
parser.add_argument(
|
||||
"--tokenizer_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory that saves tokenizer, defaults to None, will not upload tokenizer if not specified."
|
||||
)
|
||||
parser.add_argument("--commit_message", type=str, default=None, help="Message to commit while pushing.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
help="Which device to load the model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Whether or not the repository created should be private."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_temp_dir",
|
||||
action="store_true",
|
||||
help="Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
push_to_hub_kwargs = {
|
||||
"repo_id": args.repo_id,
|
||||
"commit_message": args.commit_message,
|
||||
"private": args.private,
|
||||
"use_temp_dir": args.use_temp_dir
|
||||
}
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device=args.device)
|
||||
model.push_to_hub(**push_to_hub_kwargs)
|
||||
model.config.push_to_hub(**push_to_hub_kwargs)
|
||||
model.quantize_config.push_to_hub(**push_to_hub_kwargs)
|
||||
|
||||
if args.tokenizer_dir:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
|
||||
tokenizer.push_to_hub(**push_to_hub_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -7,6 +7,8 @@ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
|||
pretrained_model_dir = "facebook/opt-125m"
|
||||
quantized_model_dir = "opt-125m-4bit-128g"
|
||||
|
||||
# os.makedirs(quantized_model_dir, exist_ok=True)
|
||||
|
||||
|
||||
def main():
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
|
||||
|
@ -19,46 +21,29 @@ def main():
|
|||
quantize_config = BaseQuantizeConfig(
|
||||
bits=4, # quantize model to 4-bit
|
||||
group_size=128, # it is recommended to set the value to 128
|
||||
desc_act=False, # setting to False can significantly speed up inference but the perplexity may be slightly worse
|
||||
)
|
||||
|
||||
# load un-quantized model, by default, the model will always be loaded into CPU memory
|
||||
# load un-quantized model; the model will always be force-loaded into CPU memory
|
||||
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
|
||||
|
||||
# quantize model; the examples should be a list of dicts whose keys can only be "input_ids" and "attention_mask"
|
||||
model.quantize(examples)
|
||||
# quantize model; the examples should be a list of dicts whose keys contain "input_ids" and "attention_mask"
|
||||
# with value under torch.LongTensor type.
|
||||
model.quantize(examples, use_triton=False)
|
||||
|
||||
# save quantized model
|
||||
model.save_quantized(quantized_model_dir)
|
||||
|
||||
# push quantized model to Hugging Face Hub.
|
||||
# to use use_auth_token=True, log in first via `huggingface-cli login`.
|
||||
# or pass an explicit token with: use_auth_token="hf_xxxxxxx"
|
||||
# (uncomment the following three lines to enable this feature)
|
||||
# repo_id = f"YourUserName/{quantized_model_dir}"
|
||||
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
|
||||
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
|
||||
|
||||
# alternatively you can save and push at the same time
|
||||
# (uncomment the following three lines to enable this feature)
|
||||
# repo_id = f"YourUserName/{quantized_model_dir}"
|
||||
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
|
||||
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)
|
||||
|
||||
# save quantized model using safetensors
|
||||
model.save_quantized(quantized_model_dir, use_safetensors=True)
|
||||
|
||||
# load quantized model to the first GPU
|
||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
|
||||
|
||||
# download quantized model from Hugging Face Hub and load to the first GPU
|
||||
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
|
||||
# load quantized model; currently only CPU or a single GPU is supported
|
||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
|
||||
|
||||
# inference with model.generate
|
||||
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
|
||||
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to("cuda:0"))[0]))
|
||||
|
||||
# or you can also use pipeline
|
||||
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
|
||||
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, device="cuda:0")
|
||||
print(pipeline("auto-gptq is")[0]["generated_text"])
|
||||
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff.