Compare commits

...

554 commits
v0.0.5 ... main

Author SHA1 Message Date
ced04e1dff disable the error exit here, see if the pregen code works 2023-10-26 12:43:07 -04:00
07021b9a1c Generated files so that when they fail to work in the pipeline it still continues with what should be some OK defaults 2023-10-26 10:26:42 -04:00
3011e13009 Built locally for temp setup; not sure what it's doing, but it is doing weird stuff on the build server, like it never determines something 2023-10-26 10:26:13 -04:00
153c085a32 Make this fail early when the actual problem happens 2023-10-26 09:38:59 -04:00
Automation Pipeline
9fb99f61e7 Merge remote-tracking branches 'laaza/Mistral' and 'laaza/MPT' 2023-10-22 07:53:59 -04:00
Vivek Khandelwal
e4b2493733
Modify qlinear_cuda for tracing the GPTQ model (#367)
Changes:
-- The change to torch.bitwise_and is made because, during tracing,
   the current usage of torch.bitwise_and results in an in-place
   variant of this op, which causes an issue in the downstream
   lowering pipeline of the traced model via Torch-MLIR and
   IREE-SHARK. The op usage is therefore changed so that it does not
   produce an in-place variant.

-- The change to the torch.matmul call in the forward function is
   made because it currently assumes the weights will always be of
   fp16 type. When the model is executed with float32 weights, this
   results in an error. The change therefore casts the LHS of the
   matmul to the same dtype as the RHS.

Neither of the above changes affects the model in any way.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-10-21 01:06:01 +09:00
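
A minimal PyTorch sketch of the two changes described above (the function and tensor names are hypothetical, not the actual qlinear_cuda code):

import torch

def dequant_and_matmul(x, qweight, mask, weights):
    # Use the functional, out-of-place form so tracing does not record an
    # in-place bitwise_and variant (which broke the Torch-MLIR / IREE-SHARK lowering).
    masked = torch.bitwise_and(qweight, mask)
    # Cast the matmul LHS to the RHS dtype so float32 weights work as well as fp16.
    out = torch.matmul(x.to(weights.dtype), weights)
    return masked, out
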
LaaZa
4b7389ddb7 Merge branch 'main' into MPT
# Conflicts:
#	auto_gptq/modeling/__init__.py
#	auto_gptq/modeling/_const.py
#	auto_gptq/modeling/auto.py
2023-10-04 20:21:49 +03:00
LaaZa
99acbead42 Add support for Mistral models. 2023-10-04 01:07:55 +03:00
潘其威(William)
51c043c6be
Merge pull request #355 from PanQiWei/fix_pack_model_use_exllamav2
import exllama QuantLinear instead of exllamav2's in `pack_model`
2023-09-27 11:06:35 +08:00
student686
c1a3013c45 import exllama QuantLinear instead of exllamav2's 2023-09-27 11:05:13 +08:00
潘其威(William)
3b81fb5ea0
Merge pull request #354 from PanQiWei/revert-325-main
Reverts #325 as it may break exllama kernels
2023-09-27 10:39:00 +08:00
潘其威(William)
3de7fbb0d5
Revert "fix bug(breaking change) remove (zeors -= 1)" 2023-09-27 10:37:31 +08:00
潘其威(William)
ac23d6b819
Merge pull request #325 from qwopqwop200/main
remove an unnecessary line (zeros -= 1) to make disabling the 'sym' feature truly possible
2023-09-26 14:20:39 +08:00
潘其威(William)
62fd0371ac
Merge branch 'main' into main 2023-09-26 14:09:04 +08:00
潘其威(William)
b461b6fa13
Merge pull request #335 from z80maniac/ignore-extra-args
Ignore unknown parameters in quantize_config.json
2023-09-26 14:00:38 +08:00
潘其威(William)
04db761eed
Merge pull request #347 from alex4321/peft-model-use-adapter-name
Use `adapter_name` for `get_gptq_peft_model` with `train_mode=True`
2023-09-26 13:55:06 +08:00
潘其威(William)
50d2e86890
Merge pull request #349 from SunMarc/exllamav2_integration
exllamav2 integration
2023-09-26 13:49:59 +08:00
Marc Sun
c912bf361a exllamav2 integration 2023-09-25 16:51:18 +00:00
student686
645bd15a96 update README 2023-09-25 18:55:34 +08:00
student686
d2844437fd update README 2023-09-25 18:53:03 +08:00
student686
da84da846b update README 2023-09-25 18:51:03 +08:00
student686
50da063f65 update README 2023-09-25 18:47:40 +08:00
Alexander Pozharskii
0185095402 Use adapter_name for get_gptq_peft_model with train_mode=True 2023-09-24 17:11:19 +04:00
潘其威(William)
06e071e68e
Merge pull request #326 from TheBloke/TB_Latest_Falcon
Add support for Falcon as part of Transformers 4.33.0, including new Falcon 180B
2023-09-14 22:49:25 +08:00
PanQiWei
7a75176224 update README 2023-09-11 11:15:08 +08:00
ZXED
121dbd15a5
Ignore unknown parameters in quantize_config.json 2023-09-10 18:39:40 +03:00
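
A minimal sketch of the idea behind this change (the helper name and the dataclass-based config class are illustrative assumptions, not the project's actual loader):

from dataclasses import fields

def filter_quantize_config(raw: dict, config_cls) -> dict:
    # Keep only keys the config dataclass defines and silently drop the rest,
    # instead of raising on unknown entries in quantize_config.json.
    known = {f.name for f in fields(config_cls)}
    return {k: v for k, v in raw.items() if k in known}
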
qwopqwop200
94de4ef185
GPTQ backward compatibility support 2023-09-08 10:16:29 +09:00
qwopqwop200
9e0682a63e
Optimize q4_matmul
https://github.com/turboderp/exllama/pull/275
2023-09-07 12:54:46 +09:00
TheBloke
034f6730ed Removed unexpected file that shouldn't have been added, sorry 2023-09-06 18:08:30 +01:00
TheBloke
02a87dce76 Add support for Falcon as part of Transformers 4.33.0, including new Falcon 180B 2023-09-06 18:03:33 +01:00
qwopqwop200
6b1ceb1897
if exllama, auto disable fused attention 2023-09-06 18:14:04 +09:00
qwopqwop200
ad5b0d72ee
fix bug 2023-09-06 16:41:41 +09:00
qwopqwop200
f752336cda
fix bug 2023-09-06 16:39:22 +09:00
潘其威(William)
1793227283
Merge pull request #311 from SunMarc/fix_max_input_length
fix typo in max_input_length
2023-09-01 10:21:54 +08:00
潘其威(William)
782bb603d9
Merge pull request #303 from JustinLin610/patch-1
Update qwen.py for Qwen-VL
2023-09-01 10:20:24 +08:00
Marc Sun
04b321da89
fix typo 2023-08-31 14:07:16 -04:00
潘其威(William)
1e938e6bad
Merge pull request #310 from PanQiWei/fix_to()_metod_bug
fix model type changed after calling .to() method
2023-08-31 19:04:02 +08:00
潘其威(William)
1339db3045
Merge pull request #309 from PanQiWei/install-skip-qigen(windows)
skip qigen installation on windows
2023-08-31 19:03:43 +08:00
PanQiWei
c7021f0f44 fix model type changed after calling .to() method 2023-08-31 18:39:03 +08:00
qwopqwop200
f97b77a64e
fix install bug 2023-08-31 15:00:38 +09:00
qwopqwop200
45a1ee4d84
install check qigen 2023-08-31 14:37:39 +09:00
qwopqwop200
71d56c76d0
skip install qigen(windows) 2023-08-31 14:35:04 +09:00
Junyang Lin
7c39a3a315
Update qwen.py for Qwen-VL
add transformer.visual as outside layer for the adaptation to Qwen-VL
2023-08-30 16:29:55 +08:00
PanQiWei
604c96144f temporarily set the version of main branch to 0.5.0.dev0 2023-08-25 17:36:23 +08:00
潘其威(William)
6bbf70373f
Merge pull request #288 from PanQiWei/revert-287-v0.4.2-release
Revert "V0.4.2 release"
2023-08-25 17:34:27 +08:00
潘其威(William)
e5050a5650
Revert "V0.4.2 release" 2023-08-25 17:26:55 +08:00
潘其威(William)
1049fd014a
Merge pull request #287 from PanQiWei/v0.4.2-release
V0.4.2 release
2023-08-25 17:26:41 +08:00
qwopqwop200
6a9d80eddc Merge remote-tracking branch 'qwopqwop200/main' into main 2023-08-25 18:06:03 +09:00
qwopqwop200
dafdd6189a
duplicate code remove 2023-08-25 14:59:13 +09:00
fxmarty
144302f58f
Update install instructions (#286) 2023-08-25 04:17:25 +09:00
fxmarty
ef442d9f70 Fix setuptools classifier (#285) 2023-08-24 19:34:10 +02:00
fxmarty
0365188c9c
Fix setuptools classifier (#285) 2023-08-25 02:33:28 +09:00
Félix Marty
8254da4f15 update version 2023-08-24 17:47:14 +02:00
fxmarty
10e6fda832
fix powershell (#284) 2023-08-24 23:53:07 +09:00
fxmarty
cf942da9e2
remove ref main as we may want to trigger workflows on other branches (#282) 2023-08-24 22:55:13 +09:00
PanQiWei
78082b1c5e update README 2023-08-24 21:16:04 +08:00
潘其威(William)
8bb4d60d8f
Merge pull request #281 from fxmarty/expose-api-exllama-input-length
Expose a function to update exllama max input length
2023-08-24 20:50:18 +08:00
Felix Marty
04730ac66c expose api to set exllama max length 2023-08-24 11:22:15 +00:00
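
A usage sketch of the helper this PR exposes; the import path and exact signature are inferred from the PR title, so treat them as illustrative:

from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cuda:0")
# Resize the exllama kernel buffers so longer prompts fit.
model = exllama_set_max_input_length(model, max_input_length=4096)
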
fxmarty
3cd79c826e
Fix python version for rocm build (#278)
* fix python version

* whats the diff?
2023-08-23 23:01:22 +09:00
fxmarty
766c6c1956
fix (#277) 2023-08-23 21:50:18 +09:00
fxmarty
d53d227b7c
Update install instructions (#275)
* update readme

* update doc

* fix
2023-08-23 21:29:55 +09:00
fxmarty
d0d1a69931
use conda incubator (#276) 2023-08-23 21:18:46 +09:00
fxmarty
81801bc6e2
Use focal for RoCm build (#274) 2023-08-23 20:41:08 +09:00
fxmarty
f7b1b8291a
Free disk space for rocm build (#273) 2023-08-23 19:21:44 +09:00
fxmarty
48baeeb739
Merge pull request #272 from PanQiWei/build-wheels-on-2004
Build wheels on ubuntu 20.04
2023-08-23 18:48:25 +09:00
Félix Marty
064f74c60f update ubuntu version 2023-08-23 11:46:19 +02:00
PanQiWei
40945beb0e update README 2023-08-22 20:18:59 +08:00
PanQiWei
4160db15e9 update README 2023-08-22 17:24:22 +08:00
qwopqwop200
f23a06f911
Merge branch 'PanQiWei:main' into main 2023-08-17 15:22:43 +09:00
qwopqwop200
b8a42911a6
qigen refactoring 2023-08-17 15:22:16 +09:00
qwopqwop200
5d5b687ca8
qigen formatting qlinear 2023-08-17 15:19:01 +09:00
qwopqwop200
084c9d8860
name change 2023-08-17 15:17:09 +09:00
潘其威(William)
eea67b7e13
Merge pull request #256 from PanQiWei/rocm_build_bug_fix
Rocm build bug fix
2023-08-13 17:14:40 +08:00
PanQiWei
8542b3dc9f execute setup tools install before torch install 2023-08-13 16:49:21 +08:00
PanQiWei
79b697743f disable free disk space action 2023-08-13 16:41:20 +08:00
PanQiWei
893fc5d7a3 release 0.4.1 2023-08-13 16:35:59 +08:00
PanQiWei
34b4ba451c fix typo 2023-08-13 16:26:02 +08:00
qwopqwop200
051f3facc7
change arguments name 2023-08-11 16:10:32 +09:00
qwopqwop200
a807e038bb
remove many contiguous and change arguments name 2023-08-11 16:09:42 +09:00
qwopqwop200
c591d6a1e1
change name make_quant_cpu to make_quant_qigen 2023-08-11 15:12:33 +09:00
qwopqwop200
2c1afc2ad9
change name make_quant_cpu to make_quant_qigen 2023-08-11 15:04:58 +09:00
qwopqwop200
aa5528cb10
use_cpu name change and default dtype change 2023-08-11 09:51:36 +09:00
qwopqwop200
870be83bea
Merge branch 'PanQiWei:main' into main 2023-08-10 22:48:30 +09:00
qwopqwop200
7ba78af3ae support cpu 2023-08-10 22:48:04 +09:00
潘其威(William)
1832685121
Merge pull request #243 from fxmarty/patch-act-order-exllama
Patch exllama QuantLinear to avoid modifying the state dict
2023-08-10 11:15:10 +08:00
qwopqwop200
1b3723a584 install qigen and move file 2023-08-10 10:06:08 +09:00
Felix Marty
4af7ea619d patch for transformers compatibility 2023-08-09 14:23:59 +00:00
PanQiWei
6a277c87cf fix syntax error 2023-08-09 19:44:29 +08:00
PanQiWei
d178ebd2fe set branch to main 2023-08-09 19:39:05 +08:00
PanQiWei
9978d6e9f9 set branch to rocm_build_bug_fix 2023-08-09 19:37:05 +08:00
PanQiWei
aea761042d fix only one python version used 2023-08-09 19:36:37 +08:00
PanQiWei
69cdfe80fd fix syntax error 2023-08-09 18:05:30 +08:00
PanQiWei
44c7a1a184 make exllama_kernels compilation as optional 2023-08-09 17:42:22 +08:00
PanQiWei
e30bb69dee revert to remove 3.11 support 2023-08-09 14:55:42 +08:00
PanQiWei
918842a083 Revert "remove 3.11 for now"
This reverts commit b5a7c813e3.
2023-08-09 14:51:11 +08:00
PanQiWei
c5acab3aec Revert "remove upload sdist step"
This reverts commit 9baff43f6f.
2023-08-09 14:49:13 +08:00
PanQiWei
b5a7c813e3 remove 3.11 for now 2023-08-09 14:34:03 +08:00
PanQiWei
60ea23d464 fix py3.11 can't build 2023-08-09 14:05:44 +08:00
PanQiWei
115f004c5e fix wrong index-url 2023-08-09 13:37:19 +08:00
PanQiWei
6d5ce1d386 temporarily disable rocm 5.5 and 5.6 support until pytorch 2.1.0 is officially released 2023-08-09 13:35:52 +08:00
PanQiWei
db9eabfc4b add disable_exllama argument 2023-08-09 12:05:15 +08:00
PanQiWei
172deae049 expose disable_exllama argument 2023-08-09 12:03:31 +08:00
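
A hedged usage sketch of the new disable_exllama argument (the model path is a placeholder):

from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",
    device="cuda:0",
    disable_exllama=True,  # fall back to the non-exllama kernels
)
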
PanQiWei
86a3d4a094 release 0.4.0 2023-08-09 11:54:31 +08:00
潘其威(William)
3fb7d1ed1c
Merge pull request #240 from PanQiWei/support-qwen
support qwen
2023-08-08 19:24:24 +08:00
qwopqwop200
fe244503e0
add "," 2023-08-08 19:57:23 +09:00
qwopqwop200
d22f89c524
support qwen 2023-08-08 19:27:43 +09:00
潘其威(William)
5981f15dc3
Merge pull request #236 from PanQiWei/suppprt-static_groups
Support static groups and fix bug
2023-08-08 14:29:39 +08:00
qwopqwop200
dc5541e78a
static groups default value change 2023-08-08 14:11:39 +09:00
qwopqwop200
2f48780165
fix bug disable exllama 2023-08-07 16:28:30 +09:00
qwopqwop200
25972d65bf
support static_groups and fix bug 2023-08-07 16:27:48 +09:00
qwopqwop200
6233afce3b
support static_groups 2023-08-07 16:25:44 +09:00
PanQiWei
d427489911 update README 2023-08-06 20:03:33 +08:00
潘其威(William)
9e2741c99d
Merge pull request #219 from fxmarty/exllama-q4-kernel
Add exllama q4 kernel
2023-08-06 19:53:40 +08:00
fxmarty
71f23268eb
Merge pull request #1 from qwopqwop200/exllama-q4-kernel
Exllama q4 kernel
2023-08-05 00:15:22 +09:00
Felix Marty
c203a85dee fix 2023-08-04 15:00:12 +00:00
Felix Marty
d0608b09db rocm support 2023-08-04 13:38:02 +00:00
Félix Marty
4fb3e20c5e Merge branch 'main' into exllama-q4-kernel 2023-08-04 15:13:34 +02:00
潘其威(William)
5d8fa85029
Merge pull request #226 from LeiWang1999/fix/general_attr
Register quant params in GeneralQuantLinear for friendly post process.
2023-08-04 18:42:54 +08:00
潘其威(William)
45152b7add
Merge pull request #220 from fxmarty/fix-revison-loading
Fix revision used to load the quantization config
2023-08-04 18:25:22 +08:00
潘其威(William)
63da0cd00a
Merge pull request #214 from fxmarty/rocm-support
Add RoCm support
2023-08-04 18:24:15 +08:00
潘其威(William)
e6790ba2cb
Update README.md 2023-08-04 18:16:55 +08:00
qwopqwop200
79ab5076c7
revert fused_llama_attn.py 2023-08-04 18:19:54 +09:00
qwopqwop200
068210d0b7
exllama support flash attention 2023-08-03 16:30:16 +09:00
qwopqwop200
7a7df5655a
support group query attention 2023-08-03 16:08:49 +09:00
leiwang1999
a0de5c2c51 regist buffer of general quant linear 2023-08-03 05:15:09 +00:00
qwopqwop200
3fc097dcd8
change pack func to support only 4 bit 2023-08-01 20:01:45 +09:00
qwopqwop200
a1fd81c72d
if training disable exllama 2023-08-01 12:29:58 +09:00
qwopqwop200
a60c9a8552
add pack fun 2023-08-01 12:22:41 +09:00
Felix Marty
7ce182888c fix atol 2023-07-31 16:46:15 +00:00
Felix Marty
339c57a902 fix 2023-07-31 15:57:44 +00:00
Felix Marty
129fa4b67e act-order now works fine 2023-07-31 15:36:58 +00:00
Felix Marty
1f99b94ae2 fix revision 2023-07-31 15:03:33 +00:00
Felix Marty
208bbfb419 fix test 2023-07-31 14:59:42 +00:00
Felix Marty
d595b0be06 add test act-order wip (not passing) 2023-07-31 14:29:19 +00:00
Felix Marty
5660b22f28 fix bug quantization config loading 2023-07-31 14:28:37 +00:00
Felix Marty
7c72b72d15 add test 2023-07-31 13:47:43 +00:00
Felix Marty
38447262c0 fix fused attn 2023-07-31 13:46:32 +00:00
Felix Marty
760667dccc cleaning 2023-07-31 11:58:10 +00:00
Felix Marty
179776bd1d exllama kernel 2023-07-31 11:50:45 +00:00
Felix Marty
23eb519e68 typo 2023-07-28 17:45:34 +00:00
Felix Marty
caf6625b68 warning about triton 2023-07-28 17:42:37 +00:00
Felix Marty
d20540173c change repo owner back to PanQiWei 2023-07-28 16:43:58 +00:00
Felix Marty
121399f8e5 fix workflows 2023-07-28 16:42:39 +00:00
Felix Marty
677d23be2d style 2023-07-28 15:14:46 +00:00
Felix Marty
8112239848 add workflow, edit readme & add tests 2023-07-28 15:10:39 +00:00
Felix Marty
2cb191e114 fix bugs 2023-07-28 14:10:44 +00:00
Felix Marty
547fb198d1 fix 2023-07-27 12:36:25 +00:00
Félix Marty
0b8a1f922d is it as simple as that? 2023-07-27 12:14:33 +02:00
LaaZa
6ff6bc8dfc Merge branch 'main' into MPT
# Conflicts:
#	auto_gptq/modeling/__init__.py
#	auto_gptq/modeling/_const.py
#	auto_gptq/modeling/auto.py
2023-07-26 20:41:19 +03:00
PanQiWei
a7167b108c simplify setup.py 2023-07-26 19:18:05 +08:00
PanQiWei
6395e4b301 update setup.py 2023-07-26 18:58:49 +08:00
PanQiWei
d6b6ec83ef Merge remote-tracking branch 'origin/main' 2023-07-26 18:41:01 +08:00
PanQiWei
1138240385 update version to 0.3.2 2023-07-26 18:40:44 +08:00
潘其威(William)
b0889e4dab
Merge pull request #212 from casperbh96/main
Fix build on non-CUDA machines after #206
2023-07-26 18:35:53 +08:00
Casper
c68b4492f6 Fix build on non-CUDA machines after #206 2023-07-26 12:21:58 +02:00
PanQiWei
ff1f100ded remove argument 'save_dir' in method from_quantized 2023-07-26 17:58:04 +08:00
PanQiWei
722a621aaa simplified code 2023-07-26 17:53:47 +08:00
PanQiWei
5d6862ee8d update README 2023-07-26 14:18:26 +08:00
潘其威(William)
22748dd2b7
Merge pull request #209 from PanQiWei/fix_no_cuda_kernel
Fix error raised when CUDA kernels are not installed
2023-07-26 14:07:30 +08:00
潘其威(William)
fd24e84eb2
Merge pull request #166 from casperbh96/main
[FEATURE] Implement perplexity metric to compare against llama.cpp
2023-07-26 14:04:51 +08:00
PanQiWei
5883b45d73 fix error raised when cuda kernels are not installed 2023-07-26 13:59:28 +08:00
潘其威(William)
bbc4a7c455
Merge pull request #208 from TheBloke/TB_Add_SafeTensors_Metadata
Add Safetensors metadata saving, with some values saved to each .safetensor file
2023-07-26 11:54:47 +08:00
潘其威(William)
228867a753
Merge pull request #207 from TheBloke/TB_version
Add a central version number
2023-07-26 11:27:23 +08:00
潘其威(William)
cbc319b4c8
Merge pull request #206 from TheBloke/TB_InstallScript
Change the install script so it attempts to build the CUDA extension in all cases
2023-07-26 11:20:53 +08:00
潘其威(William)
2456f71125
Merge pull request #205 from TheBloke/TB_fix_revision
Fix `revision` and other huggingface_hub kwargs in .from_quantized()
2023-07-26 10:34:43 +08:00
潘其威(William)
df4c4312ff
Merge pull request #202 from PanQiWei/fix-cuda-bug
Fix cuda bug that prevents group_size and desc_act from being used together
2023-07-26 10:32:18 +08:00
TheBloke
2647c92743 safetensors_metadata: add conversion to str() for input metadata to avoid errors from save_safe. Warn if this results in keys being overwritten. 2023-07-25 21:14:21 +00:00
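
A minimal sketch of the conversion this commit describes (the helper name is hypothetical): safetensors metadata must be a str-to-str mapping, hence the coercion and the overwrite warning.

import warnings

def sanitize_metadata(metadata: dict) -> dict:
    clean = {}
    for key, value in metadata.items():
        skey = str(key)
        if skey in clean:
            warnings.warn(f"Metadata key {skey!r} collides with an existing key after str() conversion and will be overwritten")
        clean[skey] = str(value)
    return clean
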
TheBloke
ee7d80945b Add version to metadata using new value 2023-07-25 14:25:24 +00:00
TheBloke
3817d154af Merge branch 'TB_version' into TB_Add_SafeTensors_Metadata 2023-07-25 14:09:29 +00:00
TheBloke
7575eae6ab Added to __init__.py to show a central version number. Also slightly adjust way version is stored in setup.py to make it easier to edit on version update. Bump version to 0.3.1 in both 2023-07-25 14:06:51 +00:00
TheBloke
eeaf5ebc53 Extend huggingface_hub features to AutoGPTQForCausalLM.from_pretrained() so models can be quantised from the hub including using a private token and revision/branch etc 2023-07-25 13:26:37 +00:00
TheBloke
593d32cb45 Typo in version joining 2023-07-25 13:18:52 +00:00
TheBloke
c9124e3fc7 Fix revision and other huggingface_hub args for .from_quantized(), which were not being passed through 2023-07-25 12:48:33 +00:00
TheBloke
6fc69c5b83 Fix check for Torch CUDA version 2023-07-25 12:45:27 +00:00
TheBloke
29da6c239f setup.py now builds CUDA ext unless BUILD_CUDA_EXT=0. Also add a check of CUDA_VERSION from Torch, if available. GITHUB_ACTIONS=true is no longer needed. 2023-07-25 11:44:43 +00:00
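
A rough sketch of the build gate this commit describes, assuming the same environment variable name; this is not the repository's actual setup.py:

import os

BUILD_CUDA_EXT = int(os.environ.get("BUILD_CUDA_EXT", "1"))  # build the CUDA extension unless explicitly disabled

TORCH_CUDA_VERSION = None
if BUILD_CUDA_EXT:
    try:
        import torch
        TORCH_CUDA_VERSION = torch.version.cuda  # e.g. "11.8"; None on CPU-only torch builds
    except ImportError:
        pass  # torch not installed yet, so skip the version check
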
TheBloke
3f359fc778 Add support for Safetensors metadata 2023-07-25 11:30:39 +00:00
qwopqwop200
9578c59d31
fix cuda bug 2023-07-25 16:50:05 +09:00
qwopqwop200
ed2aa9368e
fix cuda bug 2023-07-25 16:46:32 +09:00
PanQiWei
45576f0933 0.3.0 release 2023-07-16 15:24:06 +08:00
潘其威(William)
c2c5a74f4b
Merge pull request #158 from MarisaKirisame/main
Fix stale documentation
2023-07-11 10:43:09 +08:00
潘其威(William)
79f8a08a6d
Merge pull request #189 from cczhong11/main
Add support for InternLM
2023-07-11 10:42:30 +08:00
tc
e28e8ee809 Add support for InternLM 2023-07-07 09:25:40 -07:00
PanQiWei
590219d048 update README 2023-07-06 17:15:50 +08:00
Casper
1949e8607d Fix usage of device 2023-06-19 20:16:16 +02:00
Casper
5b88f03bba Create example of how to evaluate perplexity 2023-06-19 20:03:42 +02:00
Casper
992a0ab102 Reference Perplexity class 2023-06-19 20:03:32 +02:00
Casper
b351c8c547 Add perplexity calculation class 2023-06-19 20:03:22 +02:00
潘其威(William)
046c031139
Merge pull request #141 from AngainorDev/patch-1
Fix error message
2023-06-19 10:11:10 +08:00
潘其威(William)
93368c4e36
Merge pull request #164 from LaaZa/Baichuan
Add support for BaiChuan model
2023-06-19 10:08:21 +08:00
LaaZa
03577a7698 Rename the class to match reference capitalisation 2023-06-18 21:01:07 +03:00
LaaZa
9fd558f2ba Add support for Baichuan 2023-06-18 20:13:29 +03:00
Marisa Kirisame
ae80f2dc72 fix stale documentation 2023-06-14 20:04:27 +00:00
PanQiWei
9baff43f6f remove upload sdist step 2023-06-08 14:07:38 +08:00
PanQiWei
1b226c7bcf revert use absolute path in include_dirs 2023-06-08 14:03:08 +08:00
PanQiWei
7520133a74 upload sdist at every job 2023-06-08 14:02:46 +08:00
PanQiWei
0d4c54add9 remove build_sdist_wheel_cpu_only.yml 2023-06-08 14:02:21 +08:00
PanQiWei
bb4924e9a8 change workflow name 2023-06-08 12:33:59 +08:00
PanQiWei
8801cdf340 build_wheels.yml -> build_wheels_cuda.yml 2023-06-08 12:32:42 +08:00
PanQiWei
67bb388bf8 remove build_sdst job 2023-06-08 12:32:17 +08:00
PanQiWei
15d1981f25 add build_sdist_wheel_cpu_only.yml 2023-06-08 12:31:35 +08:00
PanQiWei
590685cad5 use absolute path in include_dirs 2023-06-08 12:30:14 +08:00
潘其威(William)
2ea23297c6
Merge pull request #140 from geekinglcq/fix_issue95
fix weights not transpose for Conv1D/2D in qlinear_cuda_old
2023-06-06 19:54:57 +08:00
Angainor Development
e75611e1b7
Fix error message 2023-06-05 22:19:09 +02:00
lunar
618a5f50ee
Add transpose operator when replace Conv1d with qlinear_cuda_old 2023-06-05 23:11:18 +08:00
潘其威(William)
bf521cbe7b
Merge pull request #134 from TheBloke/TB_benchmark
add command flags inject_fused_attention and inject_fused_mlp
2023-06-05 23:02:36 +08:00
潘其威(William)
8a1616c63c
Merge pull request #102 from PanQiWei/peft_integration
Peft integration
2023-06-05 22:55:34 +08:00
PanQiWei
129884c598 update version to 0.3.0.dev0 2023-06-05 22:53:50 +08:00
PanQiWei
b132d774e3 update README 2023-06-05 22:53:17 +08:00
TheBloke
edb13d493e Default inject_fused_attention and mlp to True, matching defaults 2023-06-03 17:58:40 +01:00
TheBloke
4617629f0c Support setting inject_fused_attention and inject_fused_mlp to False 2023-06-03 17:48:36 +01:00
PanQiWei
923fc87a11 Merge branch 'main' into peft_integration 2023-06-03 19:10:41 +08:00
潘其威(William)
023bb1c593
Merge pull request #125 from PanQiWei/support-32dim
Support 32dim
2023-06-03 19:08:29 +08:00
潘其威(William)
95a4381f50
Merge pull request #126 from PanQiWei/support-cuda-64dim
Support cuda 64dim
2023-06-03 19:08:12 +08:00
潘其威(William)
810ed4de66
Merge pull request #132 from EliEron/patch-1
Specify UTF-8 encoding for README.md in setup.py
2023-06-03 10:57:52 +08:00
qwopqwop200
f4820f2988
change qlinear cuda support 64dim 2023-06-03 07:30:34 +09:00
qwopqwop200
8951212ab3
change setup 2023-06-03 07:29:19 +09:00
qwopqwop200
e04c3b86cc
add cuda 2023-06-03 07:28:35 +09:00
qwopqwop200
5fc2064e1a
Rename autogptq_cuda_kernel.cu to autogptq_cuda_kernel_64.cu 2023-06-03 07:27:45 +09:00
qwopqwop200
446e12d3de
Rename autogptq_cuda.cpp to autogptq_cuda_64.cpp 2023-06-03 07:27:31 +09:00
EliEron
eeb8b78a55
Specify Encoding when reading README.md
Prevents UnicodeDecodeError from being raised in certain locales.
2023-06-02 20:54:32 +02:00
LaaZa
bf47892b81 Merge branch 'main' into MPT
# Conflicts:
#	auto_gptq/modeling/__init__.py
#	auto_gptq/modeling/_const.py
#	auto_gptq/modeling/auto.py
2023-06-02 15:01:10 +03:00
潘其威(William)
b4fdd8d264
Merge branch 'main' into peft_integration 2023-06-02 19:11:59 +08:00
PanQiWei
7206705456 set version to 0.2.1 2023-06-02 19:07:56 +08:00
qwopqwop200
2df7d7105d
support 64 cuda dim 2023-06-02 19:54:37 +09:00
qwopqwop200
b03f53294f
support 64dim cuda 2023-06-02 19:53:50 +09:00
qwopqwop200
90106d7c34
support cuda 64dim 2023-06-02 19:49:38 +09:00
PanQiWei
65c0115b86 update README 2023-06-02 18:18:11 +08:00
PanQiWei
0e609bec40 only append CUDA_VERSION to release version string when in github actions 2023-06-02 18:16:38 +08:00
qwopqwop200
0891ea4036
support 32dim triton 2023-06-02 19:05:55 +09:00
qwopqwop200
b3654a68c3
support 32dim triton kernel 2023-06-02 19:04:12 +09:00
PanQiWei
50ac2ad4bc update README 2023-06-02 10:59:36 +08:00
PanQiWei
113884d976 Merge remote-tracking branch 'origin/main' 2023-06-02 10:57:16 +08:00
PanQiWei
b248a2655a update README 2023-06-02 10:56:57 +08:00
潘其威(William)
f948b56c07
Merge pull request #123 from jllllll/main
Fix and extend build_wheels.yml workflow
2023-06-02 10:12:20 +08:00
jllllll
0f1793b554
Revert "Remove workflow restriction for testing"
This reverts commit e62bda1c1e.
2023-06-01 20:49:42 -05:00
jllllll
3c6a002be5
Clean up workflow sdist creation 2023-06-01 20:35:30 -05:00
jllllll
e62bda1c1e
Remove workflow restriction for testing 2023-06-01 20:27:40 -05:00
jllllll
996382788b
Finalize workflow fix 2023-06-01 13:58:24 -05:00
jllllll
198e079da4
Restrict build_wheels.yml to minimum compute 6.0 2023-06-01 13:25:04 -05:00
jllllll
a0063fc9db
Add GitHub Actions bypass for cuda check to setup.py 2023-06-01 13:07:00 -05:00
jllllll
3084422095
Merge branch 'PanQiWei:main' into main 2023-06-01 12:50:34 -05:00
PanQiWei
b5db750c00 update setup.py 2023-06-02 01:39:56 +08:00
jllllll
2b96343e87
Update build_wheels.yml (#1) 2023-06-01 12:39:56 -05:00
PanQiWei
6a37f7c266 update setup.py 2023-06-02 00:03:44 +08:00
PanQiWei
bc61e51394 update README 2023-06-01 10:35:17 +08:00
PanQiWei
7ae89f282a update build_wheels.yml 2023-06-01 01:48:29 +08:00
PanQiWei
31b8c1313e update build_wheels.yml 2023-06-01 01:34:46 +08:00
PanQiWei
d53a30d351 update build_wheels.yml 2023-06-01 01:16:10 +08:00
PanQiWei
ac7dd9bc1f update build_wheels.yml 2023-06-01 01:03:35 +08:00
潘其威(William)
a63f8fd523
Merge pull request #120 from PanQiWei/add_build_wheels_workflow
Add build wheels workflow
2023-06-01 00:43:38 +08:00
PanQiWei
d780ef5eef update build_wheels.yml 2023-06-01 00:42:31 +08:00
PanQiWei
407e5d8133 add workflow to build wheels 2023-06-01 00:39:09 +08:00
PanQiWei
0ece40ca25 update setup.py 2023-06-01 00:38:35 +08:00
PanQiWei
402973259f update setup.py 2023-06-01 00:18:43 +08:00
PanQiWei
ec6603d0ab support older version python 2023-05-31 22:11:16 +08:00
潘其威(William)
93698e027d
Merge pull request #116 from PanQiWei/pytorch-qlinear
switch to use pytorch backend when triton is not available at train mode
2023-05-31 00:10:55 +08:00
qwopqwop200
b1a8cc28e8
remove raise 2023-05-31 00:03:51 +09:00
qwopqwop200
c381958a5f
add warning 2023-05-30 23:53:33 +09:00
qwopqwop200
0f2841cb13
remove log 2023-05-30 23:51:55 +09:00
qwopqwop200
33809a8e59
remove log 2023-05-30 23:51:39 +09:00
qwopqwop200
dfd9dc0e6b
change if trainable backend pytorch 2023-05-30 23:43:55 +09:00
qwopqwop200
5274313067
change if trainable backend pytorch 2023-05-30 23:40:58 +09:00
PanQiWei
d0769c1a39 update README 2023-05-30 08:01:16 +08:00
PanQiWei
e826d89dbc update basic_usage.py 2023-05-30 07:47:10 +08:00
PanQiWei
df8672ce75 update README 2023-05-30 07:44:25 +08:00
PanQiWei
448a53e6a7 delete push to hub example script 2023-05-30 07:39:58 +08:00
潘其威(William)
defc96ff04
Merge pull request #91 from TheBloke/TheBloke_support-HF-download
Add support for HF Hub download, and `push_to_hub`
2023-05-30 07:37:15 +08:00
潘其威(William)
2245fad095
Update auto.py
fix None type error
2023-05-30 07:35:15 +08:00
潘其威(William)
15db2cdc44
Update _base.py
fix problem of recursively adding file extension to model_base_name
2023-05-30 07:26:42 +08:00
潘其威(William)
cfa7271617
Update _base.py
fix variable not found error
2023-05-30 07:22:10 +08:00
潘其威(William)
e5771fb206
Update _base.py
fix key mismatch
2023-05-30 06:44:45 +08:00
潘其威(William)
61a4ea035f
Update auto.py
add back save_dir for backward compatibility
2023-05-30 06:43:00 +08:00
潘其威(William)
ea74e15199
Update _base.py
add model_name_or_path and model_file_base_name to BaseQuantizeConfig for better model file management; add back save_dir to .from_quantized() for backward compatibility
2023-05-30 06:40:31 +08:00
潘其威(William)
0021417050
Update README.md 2023-05-30 05:56:46 +08:00
潘其威(William)
243bb2d56e
Update README_zh.md
fix typo
2023-05-30 05:56:04 +08:00
潘其威(William)
9a10b8496a
Update README_zh.md
merge the example code of downloading from and uploading to HF Hub into simplest usage code above to keep README compact.
2023-05-30 05:53:27 +08:00
潘其威(William)
17db71491f
Update README.md
merge the example code of downloading from and uploading to HF Hub into simplest usage code above to keep README compact.
2023-05-30 05:49:29 +08:00
PanQiWei
9dd7784e6a update lr 2023-05-29 21:27:35 +08:00
PanQiWei
539682e951 update lr 2023-05-28 22:54:53 +08:00
PanQiWei
788128f0a6 rename example scripts 2023-05-28 22:49:01 +08:00
PanQiWei
4296ced96d add example script for AdaptionPrompt peft type 2023-05-28 22:44:32 +08:00
PanQiWei
6c64b0b361 raise NotImplementedError when model with fused attention injected try to use ADAPTION_PROMPT peft type 2023-05-28 22:35:34 +08:00
PanQiWei
7af93624d7 remove useless code 2023-05-28 22:30:38 +08:00
PanQiWei
a1f5204bfd fix ppl calculation error 2023-05-28 22:20:49 +08:00
PanQiWei
def084bf0e reset value of AdaptionPromptConfig.adapter_layers to number of model's hidden layers when exceeds 2023-05-28 22:11:02 +08:00
PanQiWei
7bb01ae1cd remove useless prints 2023-05-28 22:01:35 +08:00
PanQiWei
956029fcb2 update README.md 2023-05-28 21:42:59 +08:00
PanQiWei
271ae2926b add training example scripts for Lora and AdaLora 2023-05-28 21:35:53 +08:00
PanQiWei
801b1c13ca update example script 2023-05-28 21:30:53 +08:00
PanQiWei
ad10c13d40 support AdaLora 2023-05-28 21:30:45 +08:00
PanQiWei
3ee2daa73c make GPTQLoraModel to inherit from LoraModel to simplify code 2023-05-28 17:36:18 +08:00
PanQiWei
22d1d8dcaa add 'auto_find_all_linears' argument to get_gptq_peft_model function 2023-05-28 17:04:38 +08:00
PanQiWei
83132a663a add warning to guide users interact with lora properly 2023-05-28 16:57:31 +08:00
PanQiWei
86f060c74b Merge branch 'main' into peft_integration 2023-05-28 16:23:38 +08:00
PanQiWei
4b8524ffb1 update README 2023-05-27 23:04:00 +08:00
PanQiWei
f703c9ac98 update generation_speed.py 2023-05-27 20:15:43 +08:00
PanQiWei
1ff000658e update README.md 2023-05-27 19:26:56 +08:00
PanQiWei
e8bd3c33c4 add generation speed benchmark example script 2023-05-27 19:16:42 +08:00
PanQiWei
491da62402 fix signature at import time 2023-05-27 17:49:58 +08:00
PanQiWei
0327ac8f42 update README 2023-05-27 17:42:16 +08:00
PanQiWei
3cb1bf5a6d add trust_remote_code command line flag 2023-05-27 17:09:10 +08:00
PanQiWei
c040617a94 update README 2023-05-27 17:03:50 +08:00
PanQiWei
ceacd59e4b update NEWS_OR_UPDATE 2023-05-27 16:37:51 +08:00
潘其威(William)
0a40581270
Merge pull request #111 from PanQiWei/falcon
Falcon support
2023-05-27 16:23:30 +08:00
潘其威(William)
23998345f5
Merge branch 'main' into falcon 2023-05-27 16:23:16 +08:00
潘其威(William)
3108985a55
Merge pull request #112 from billcai/patch-2
Minor syntax fix for auto.py
2023-05-27 16:20:54 +08:00
Bill Cai
0729760234
Update auto.py 2023-05-27 11:16:43 +08:00
潘其威(William)
269ef7335c
Merge branch 'main' into falcon 2023-05-27 08:15:52 +08:00
潘其威(William)
9ebc1d1ec0
Merge pull request #63 from LaaZa/GPTBigCode
Add support for GPTBigCode(starcoder)
2023-05-27 08:03:10 +08:00
潘其威(William)
3c3b0e1e79
Merge branch 'main' into GPTBigCode 2023-05-27 08:03:03 +08:00
潘其威(William)
358ff80d09
Merge pull request #65 from LaaZa/Codegen
Add support for CodeGen/2
2023-05-27 08:01:53 +08:00
潘其威(William)
eab728b263
Merge branch 'main' into Codegen 2023-05-27 08:00:19 +08:00
潘其威(William)
f6fd314d5a
Merge branch 'main' into GPTBigCode 2023-05-27 07:57:25 +08:00
qwopqwop200
277809381b
fix bug 2023-05-27 08:53:47 +09:00
PanQiWei
5bc5325920 add find_all_linear_names help function, make customized lora module more general 2023-05-27 07:49:17 +08:00
PanQiWei
eb9c0b140f update FusedLlamaMLPForQuantizedModel for general usage purpose 2023-05-27 07:47:20 +08:00
qwopqwop200
bcb345fb35
support falcon 2023-05-27 07:53:39 +09:00
qwopqwop200
4d5b4fa5c6
add dtype 2023-05-27 07:49:28 +09:00
qwopqwop200
c14b4c1567
change find layer algorithm 2023-05-27 07:48:50 +09:00
qwopqwop200
874c9fd0ef
fix bug 2023-05-27 07:47:17 +09:00
PanQiWei
f7e705848a move peft compatible model injection to the last step 2023-05-26 14:29:33 +08:00
PanQiWei
8bf21a7e4c set xavier_uniform_ as lora_A's init function 2023-05-26 14:06:53 +08:00
PanQiWei
2b532f9453 add trainable mode 2023-05-26 13:11:30 +08:00
PanQiWei
fe5f5d12ed Merge branch 'main' into peft_integration 2023-05-26 09:48:06 +08:00
PanQiWei
69609c4bc7 support faster vecquant4matmul cuda kernel 2023-05-26 08:55:05 +08:00
PanQiWei
cfd27e8caa refactor file structure of qlinears 2023-05-26 07:18:16 +08:00
潘其威(William)
b4eda619d0
Merge pull request #104 from PanQiWei/triton-float32
triton float32 support
2023-05-25 22:56:00 +08:00
qwopqwop200
503f85255d
Update kernels.py 2023-05-25 23:15:33 +09:00
PanQiWei
f6a34137e9 lora compatibility 2023-05-25 19:44:53 +08:00
PanQiWei
d293bf3a04 first upload peft_utils.py 2023-05-25 15:11:11 +08:00
PanQiWei
4d157a3b64 add hack of __getattr__ 2023-05-25 15:10:33 +08:00
TheBloke
b7bb50b4d5 Fix bug added after merge 2023-05-25 07:05:51 +01:00
Tom Jobbins
492255b400
Merge branch 'main' into TheBloke_support-HF-download 2023-05-25 07:02:13 +01:00
PanQiWei
096749fe9d generalize QuantLinear 2023-05-25 13:33:09 +08:00
PanQiWei
49d1f0da1b update README 2023-05-25 13:06:17 +08:00
PanQiWei
6426b41f94 update setup.py 2023-05-25 13:06:10 +08:00
潘其威(William)
18c7ce5875
Merge pull request #100 from PanQiWei/improve_cpu_offload
Improve CPU offload
2023-05-24 18:48:37 +08:00
PanQiWei
c341a6df2f update tutorial 2023-05-24 18:48:19 +08:00
PanQiWei
ac14180946 update tutorial 2023-05-24 18:31:59 +08:00
PanQiWei
065fd1de35 update README 2023-05-24 18:26:47 +08:00
PanQiWei
e6ba062c08 update basic usage example code 2023-05-24 17:58:01 +08:00
PanQiWei
94ef4d5ada update basic usage example code 2023-05-24 17:56:46 +08:00
PanQiWei
c89bb6450c correct typo of function name 2023-05-24 17:43:38 +08:00
PanQiWei
10347fdd7b remove full_cpu_offload argument and unify model dispatch strategy 2023-05-24 17:41:04 +08:00
PanQiWei
379f24c2a5 remove add_align_logits_hook_to_model 2023-05-24 17:01:57 +08:00
PanQiWei
749dba1a7e disable add_align_logits_hook_to_model for now 2023-05-24 13:42:06 +08:00
PanQiWei
58c1b509f0 support add_align_logits_hook_to_model 2023-05-24 12:50:30 +08:00
PanQiWei
21ab7c435a make comments more readable 2023-05-24 11:38:29 +08:00
PanQiWei
c31b370228 make_sure_not_tensor_in_meta_device before load checkpoint 2023-05-24 11:32:45 +08:00
PanQiWei
63f1b4e073 remove comment 2023-05-24 11:23:07 +08:00
PanQiWei
057c39e3f2 fix meta device bug when use low_cpu_mem_usage 2023-05-24 11:19:59 +08:00
PanQiWei
e2e7809a1f always enable QuantLinear bias to be compatible with models quantized by other frameworks 2023-05-24 10:56:31 +08:00
PanQiWei
8e034b28bc remove duplicate code 2023-05-23 23:48:15 +08:00
PanQiWei
4373d6b29c Merge branch 'main' into improve_cpu_offload 2023-05-23 23:47:33 +08:00
PanQiWei
191da8141e fix device mismatch 2023-05-23 23:22:52 +08:00
PanQiWei
e4e90e8b0a add warmup_triton method 2023-05-23 23:18:46 +08:00
PanQiWei
ed14d3a786 fix save quantized model failed when load pretrained model using CPU offload 2023-05-23 23:17:11 +08:00
潘其威(William)
7820322089
Merge pull request #66 from LexSong/main
Fix CUDA out of memory error in qlinear_old.py
2023-05-23 23:04:45 +08:00
PanQiWei
6476ee4235 add options: 'low_cpu_mem_usage' and 'full_cpu_offload' 2023-05-23 22:51:00 +08:00
PanQiWei
c63959365a update setup.py 2023-05-23 19:30:47 +08:00
PanQiWei
1b2159bd4c add more help functions 2023-05-23 19:30:28 +08:00
PanQiWei
db63c0876a half out 2023-05-23 16:08:28 +08:00
潘其威(William)
1bb7be3dd3
Update issue templates 2023-05-23 15:55:48 +08:00
潘其威(William)
a85d65e915
Update issue templates 2023-05-23 15:53:07 +08:00
Lex Song
f2ab4fab46 Fix CUDA out of memory error in qlinear_old.py
Add a missing line from qlinear.py to qlinear_old.py to convert the output tensor.
This resolves a CUDA out of memory error that occurred without this line.
2023-05-20 21:10:11 +08:00
TheBloke
bf633c298e Clean up some unused params 2023-05-20 10:32:27 +01:00
潘其威(William)
d4011d29c6
Merge pull request #92 from PanQiWei/fix_triton_integration_bugs
fix ImportError when triton is not installed
2023-05-20 17:01:14 +08:00
潘其威(William)
809efa6fcb
Update README_zh.md 2023-05-20 16:53:27 +08:00
潘其威(William)
082e76713e
Update README.md 2023-05-20 16:52:43 +08:00
潘其威(William)
0ca1752a9b
Merge pull request #93 from TheBloke/TheBloke_rename-quant_cuda2
Rename 'quant_cuda' to 'autogptq_cuda' to avoid conflicts with existing GPTQ-for-LLaMa installations.
2023-05-20 16:44:02 +08:00
PanQiWei
b803369719 update quant_with_alpaca.py 2023-05-20 16:43:21 +08:00
PanQiWei
f78f074409 update quant_with_alpaca.py 2023-05-20 16:42:34 +08:00
TheBloke
898f1ef62d Rename 'quant_cuda' to 'autogptq_cuda' to avoid conflicts with existing GPTQ-for-LLaMa installations. 2023-05-20 09:33:51 +01:00
PanQiWei
73b5952f5e fix not return directly when triton is not installed 2023-05-20 16:21:52 +08:00
PanQiWei
86b3b52c63 fix ImportError when triton is not installed 2023-05-20 16:15:20 +08:00
潘其威(William)
13defe253a
Merge pull request #84 from TheBloke/TheBloke_forward-positional-args
Forward position args to allow `model(tokens)` syntax
2023-05-20 15:04:27 +08:00
潘其威(William)
d0b7908a2c
Merge pull request #82 from Ph0rk0z/patch-1
Update example script to include desc_act
2023-05-20 15:03:18 +08:00
潘其威(William)
1ef0af824a
Merge pull request #80 from PanQiWei/user_customized_device_map
Support users customize `device_map`
2023-05-20 15:00:05 +08:00
TheBloke
277a007ebc Minor clarification and clean up of example script 2023-05-19 18:33:19 +01:00
TheBloke
e5c8479100 Remove debugging print line 2023-05-19 17:50:48 +01:00
TheBloke
c234bf11f9 Update README with examples for HF (Chinese text is from Google Translate - please check! :) ) 2023-05-19 17:39:49 +01:00
TheBloke
735f7df4cc Add push_to_hub for HF hub uploading 2023-05-19 17:10:57 +01:00
TheBloke
908b338436 Initial support for model loading from HF hub 2023-05-19 15:57:05 +01:00
TheBloke
a397f00cc3 Implement HF cached download for quantize_config 2023-05-19 15:15:43 +01:00
Forkoz
cc835640a9
Update some help 2023-05-17 07:31:09 -05:00
Forkoz
6b0b84bc9b
Update basic_usage_gpt_xl.py 2023-05-17 07:28:53 -05:00
Forkoz
2d0aaa423f
update another example 2023-05-17 07:27:49 -05:00
Forkoz
922ec02998
Fix another example 2023-05-17 07:26:24 -05:00
TheBloke
7f165337ed Forward position args to allow syntax 2023-05-16 12:19:52 +01:00
Forkoz
eaac7a7b76
Update example script to include desc_act
It will help prevent people from unwittingly making incompatible models.
2023-05-15 11:26:22 +00:00
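
For context, a hedged example of the kind of setting the updated script exposes (values are illustrative):

from auto_gptq import BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=False,  # activation order; models quantized with different settings need matching kernel support
)
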
潘其威(William)
570867c109
Merge pull request #79 from oobabooga/main
support loading quantized model with .pt file extension
2023-05-15 16:08:44 +08:00
PanQiWei
759d6953d4 support user customize device_map 2023-05-15 13:26:38 +08:00
PanQiWei
07e06fa08c make compatible with older transformers version 2023-05-15 13:26:18 +08:00
oobabooga
86c7021285
Look for .pt files 2023-05-15 00:00:05 -03:00
潘其威(William)
262669112b
Merge pull request #76 from PanQiWei/gptj_fused_attention
Gptj fused attention
2023-05-14 16:21:27 +08:00
PanQiWei
d5429441ef add GPTJ fused attention module 2023-05-14 16:17:21 +08:00
PanQiWei
e1c564ac0e compatible with older pytorch version 2023-05-14 16:17:03 +08:00
PanQiWei
4586b3f31f update setup.py 2023-05-14 16:16:20 +08:00
PanQiWei
5445c67190 add library version comparison help functions 2023-05-14 16:16:06 +08:00
潘其威(William)
bdb08c16fc
Merge branch 'main' into Codegen 2023-05-14 13:10:52 +08:00
潘其威(William)
e24c5122db
Merge branch 'main' into GPTBigCode 2023-05-14 13:10:10 +08:00
潘其威(William)
7c248cebf6
Merge pull request #43 from PanQiWei/faster-llama
Faster llama
2023-05-14 13:09:10 +08:00
PanQiWei
e83c9fc8dd update setup.py 2023-05-14 13:08:26 +08:00
PanQiWei
de33d26d67 fix bugs 2023-05-14 13:07:18 +08:00
PanQiWei
2273f9ef39 refactor file structure for triton kernels 2023-05-14 11:49:10 +08:00
PanQiWei
fef1a4fe4b make code clean and extendable 2023-05-12 20:11:55 +08:00
PanQiWei
d718d63e9c add import_utils.py for commonly used module importation 2023-05-12 19:58:48 +08:00
潘其威(William)
6f887f666a
Update 02-Advanced-Model-Loading-and-Best-Practice.md 2023-05-12 19:47:05 +08:00
LaaZa
fb380fb9c2 Add initial support for MPT 2023-05-12 14:46:52 +03:00
PanQiWei
c5ff195764 skip fused module injection instead of raising error if it's not supported yet. 2023-05-12 19:36:00 +08:00
PanQiWei
f159aeabb6 refactor .from_quantized api and improve model loading strategy 2023-05-12 18:09:50 +08:00
PanQiWei
69610329d2 add _fused_base.py 2023-05-12 18:09:23 +08:00
潘其威(William)
393a2fbac2
Update README.md 2023-05-12 13:47:30 +08:00
潘其威(William)
e5c267e289
Update README.md 2023-05-12 13:46:41 +08:00
潘其威(William)
d6d099a1d1
Merge branch 'main' into faster-llama 2023-05-12 13:39:24 +08:00
PanQiWei
4bb10fda49 groupsize -> group_size 2023-05-12 13:37:52 +08:00
LaaZa
b8187ff05a Add support for CodeGen/2 2023-05-08 17:34:00 +03:00
LaaZa
63247a0669 Add support for GPTBigCode 2023-05-08 12:28:29 +03:00
潘其威(William)
560cf92d7d
Merge pull request #62 from lszxb/fix_incorrect_pack
fix incorrect pack while using cuda, desc_act and grouping
2023-05-08 10:35:30 +08:00
潘其威(William)
8b67f7de2f
Merge pull request #59 from Sciumo/setup_conda
Setup conda
2023-05-08 10:34:10 +08:00
lszxb
174ef81995 fix incorrect pack while using cuda, desc_act and grouping 2023-05-07 20:44:47 +08:00
Sciumo
ee4ca934aa add conda cuda include directory if found 2023-05-05 14:28:04 -04:00
Sciumo
81f3dfe39c add conda cuda include directory if found 2023-05-05 14:27:11 -04:00
qwopqwop200
3ff6ab18cb
Merge branch 'main' into faster-llama 2023-05-06 00:20:29 +09:00
潘其威(William)
7c33fa2fa4
Merge pull request #58 from TheBloke/TheBloke_faster-llama_groupsize_fix
Fix bug caused by 'groupsize' vs 'group_size' and change all code to use 'group_size' consistently
2023-05-05 23:00:59 +08:00
TheBloke
1b3329b399 Fix 'groupsize' -> 'group_size' in all other .py files. I haven't touched any CUDA kernels in case there's any complexity there I don't understand 2023-05-05 14:44:16 +01:00
TheBloke
f61ce12271 Change 'groupsize' to 'group_size' everywhere. Turns out this is easier than 'groupsize' due to dependencies in other files. 2023-05-05 13:36:00 +01:00
TheBloke
f64c71e779 Change references to 'group_size' to 'groupsize' to match the rest of this file 2023-05-05 13:21:13 +01:00
PanQiWei
374ce21066 release v0.1.0 2023-05-05 00:18:50 +08:00
PanQiWei
753c261388 update README.md 2023-05-05 00:15:33 +08:00
PanQiWei
d79aec7bd0 update README.md 2023-05-04 23:06:32 +08:00
PanQiWei
fe3456100c bug fix from commit 3c108d4232 2023-05-04 22:34:16 +08:00
PanQiWei
e4d476be16 update README.md 2023-05-04 22:17:38 +08:00
PanQiWei
6cba6e7123 reformat code 2023-05-04 22:16:08 +08:00
PanQiWei
1c6bb69fae fix attribute name error 2023-05-04 22:10:33 +08:00
潘其威(William)
771b650a7c
Merge pull request #38 from PanQiWei/faster-cuda-no-actorder
Faster cuda no actorder
2023-05-04 21:47:19 +08:00
qwopqwop200
b19c59541b
fix bug 2023-05-04 13:17:10 +09:00
qwopqwop200
908248114e
fix bug 2023-05-04 13:15:52 +09:00
qwopqwop200
b14d42e68a
bug fix 2023-05-04 13:03:38 +09:00
qwopqwop200
b0bc0b0358
bug fix 2023-05-04 13:03:11 +09:00
qwopqwop200
208d660920
fix bug 2023-05-04 10:04:00 +09:00
qwopqwop200
f51a92ed79
support faster and model load strict 2023-05-04 09:53:28 +09:00
qwopqwop200
cc992c21bd
Merge branch 'faster-cuda-no-actorder' into faster-llama 2023-05-04 09:09:09 +09:00
qwopqwop200
d49281bc5d
support faster and model load strict 2023-05-04 09:07:34 +09:00
qwopqwop200
c8504f0660
support faster and model load strict 2023-05-04 09:06:52 +09:00
qwopqwop200
34201dbff9
support faster and model load strict 2023-05-04 09:05:07 +09:00
qwopqwop200
c359f672a8
support faster and model load strict 2023-05-04 09:04:07 +09:00
qwopqwop200
afe1323b3f
support faster and model load strict 2023-05-04 09:03:36 +09:00
qwopqwop200
a88cd16d65
fix bug 2023-05-03 22:36:14 +09:00
qwopqwop200
24251d1397
check kwargs 2023-05-02 22:32:54 +09:00
qwopqwop200
26581b6946
remove LlamaGPTQForCausalLM 2023-05-02 22:18:17 +09:00
qwopqwop200
694f2954a3
add auto model parameter 2023-05-02 22:16:23 +09:00
qwopqwop200
ccd87e5800
add Auto model parameter 2023-05-02 22:15:56 +09:00
qwopqwop200
d8707f92a9
support fused_attn 2023-05-02 21:54:15 +09:00
qwopqwop200
61c6f6a5d2
typo fix 2023-05-02 21:53:39 +09:00
qwopqwop200
a11d59f6c4
support fused_attn 2023-05-02 21:53:13 +09:00
qwopqwop200
f47322f073
fix bug 2023-05-02 21:14:27 +09:00
qwopqwop200
41f2379850
bug fix 2023-05-02 20:38:17 +09:00
qwopqwop200
d2f48e5311
bug fix 2023-05-02 20:36:53 +09:00
qwopqwop200
709bd7594f
Merge pull request #44 from PanQiWei/fix-bug-cuda
Fix bug cuda
2023-05-02 19:50:59 +09:00
qwopqwop200
9490a98444
add LlamaGPTQForCausalLM 2023-05-02 19:32:18 +09:00
qwopqwop200
a6d4f5c091
fix bug 2023-05-02 19:19:04 +09:00
qwopqwop200
2ba84fbb48
fix bug 2023-05-02 19:13:40 +09:00
qwopqwop200
1388acac94
fix bug 2023-05-02 19:13:13 +09:00
qwopqwop200
6c23e5b3a5
add fused mlp ,fused attn 2023-05-02 18:55:44 +09:00
qwopqwop200
f51f763fde
fused attn ,fused mlp apply 2023-05-02 18:51:04 +09:00
qwopqwop200
50c0fd13c5
Multi-GPU, allocate output tensor 2023-05-02 17:51:41 +09:00
qwopqwop200
3c108d4232
fix bug 2023-05-02 12:00:50 +09:00
潘其威(William)
144bd80436
Merge pull request #39 from TheBloke/TheBloke_check_model_exists
Check that model_save_name exists before trying to load it, to avoid confusing checkpoint error
2023-05-01 19:55:24 +08:00
潘其威(William)
d612b0d2f4
Merge pull request #37 from PanQiWei/bug-fix
bug fix quantization demo
2023-05-01 19:36:43 +08:00
潘其威(William)
69bcdfbbf2
Merge pull request #40 from TheBloke/TheBloke_fix_typo
Fix typo: 'hole' -> 'whole'
2023-05-01 19:34:30 +08:00
TheBloke
593a0b28bb Fix typo: 'hole' -> 'whole' 2023-05-01 10:25:18 +01:00
TheBloke
60195ca5f2 Check that model_save_name exists before trying inference, to avoid confusing checkpoint error 2023-05-01 10:15:13 +01:00
qwopqwop200
f0f37c1fe7
fix bug 2023-05-01 18:09:39 +09:00
qwopqwop200
95e633a597
add old cuda 2023-05-01 13:05:14 +09:00
qwopqwop200
5a69e22a93
add qlinear_old 2023-05-01 13:04:47 +09:00
qwopqwop200
9dfcac8e26
add qlinear_old 2023-05-01 13:03:57 +09:00
qwopqwop200
a982d400fb
add old k matmul 2023-05-01 13:02:42 +09:00
qwopqwop200
64e8d93e31
add old vecquantKmatmul 2023-05-01 13:01:37 +09:00
qwopqwop200
d986a738e1
bug fix quantization demo 2023-05-01 08:03:11 +09:00
PanQiWei
e2c7cd4fb3 update README.md 2023-04-30 16:52:15 +08:00
PanQiWei
da6981ac38 add tutorial: advanced-model-loading-and-best-practice 2023-04-30 16:46:48 +08:00
PanQiWei
d8d77199e9 Merge remote-tracking branch 'origin/main' 2023-04-30 16:04:45 +08:00
PanQiWei
1749477bba update README.md 2023-04-30 16:04:30 +08:00
潘其威(William)
e3bc7f220f
Update README.md 2023-04-30 08:38:22 +08:00
PanQiWei
bc7a05b576 update README.md 2023-04-29 22:46:59 +08:00
PanQiWei
fa89cf4872 update README.md 2023-04-29 22:45:52 +08:00
PanQiWei
ff20acaafe add quick start tutorial 2023-04-29 22:31:42 +08:00
潘其威(William)
a535c5727b
Merge pull request #34 from PanQiWei/change-save-name
Change default quantized model save basename
2023-04-29 20:37:24 +08:00
潘其威(William)
5fa803334d
Merge branch 'main' into change-save-name 2023-04-29 20:36:45 +08:00
潘其威(William)
1a3748db2a
Merge pull request #33 from z80maniac/customize-params
Allow to load arbitrary quantized models
2023-04-29 20:35:41 +08:00
qwopqwop200
787909084f
fix bug 2023-04-29 19:08:34 +09:00
qwopqwop200
a2ef4b98db
change save the name 2023-04-29 18:20:46 +09:00
qwopqwop200
9317af6c40
change save name 2023-04-29 18:19:13 +09:00
qwopqwop200
605f345135
Update basic_usage_gpt_xl.py 2023-04-29 18:18:50 +09:00
qwopqwop200
eb5a27f48c
Update basic_usage_wikitext2.py 2023-04-29 18:18:35 +09:00
qwopqwop200
3b74c9758e
change save name 2023-04-29 18:18:21 +09:00
qwopqwop200
b5eb906ac9
change save name 2023-04-29 18:17:59 +09:00
qwopqwop200
05733ae482
change save the name 2023-04-29 18:17:33 +09:00
qwopqwop200
1792cd1111
change save the name 2023-04-29 18:16:48 +09:00
ZXED
24a371d14a
use the same Optional style as in other params 2023-04-29 09:52:11 +03:00
ZXED
c22770188d
allow user to set trust_remote_code flag manually 2023-04-29 09:52:11 +03:00
ZXED
b3f19a7ba7
support custom model name when loading the model 2023-04-29 09:52:11 +03:00
ZXED
ea8ab73343
support custom quantize_config when loading the model 2023-04-29 09:51:50 +03:00
PanQiWei
16d8dd200f remove non-parameters module from MOSSGPTQForCausalLM.outside_layer_modules 2023-04-29 10:58:29 +08:00
PanQiWei
d9e7363fa8 update README.md 2023-04-29 01:47:40 +08:00
PanQiWei
b490ab004e remove override of _resize_attention_mask for llama and opt 2023-04-28 23:08:42 +08:00
潘其威(William)
1d91fded6c
Merge pull request #31 from PanQiWei/add-raise-exception-and-gpt2-xl-example-add
Add example to use with GPT2-XL and raise an error when infeatures and outfeatures are not divisible by 256
2023-04-28 22:52:18 +08:00
qwopqwop200
ae8b1a22a3
change global to local 2023-04-28 23:18:39 +09:00
qwopqwop200
e914b9b1bd
update support 256 not div 2023-04-28 22:48:23 +09:00
qwopqwop200
c9215a1b5b
change div num 2023-04-28 22:42:29 +09:00
qwopqwop200
7a38c2a6ef
add basic_usage_gpt_xl 2023-04-28 22:32:46 +09:00
qwopqwop200
19f167e58b
add raise-exception 2023-04-28 22:24:44 +09:00
潘其威(William)
1e353a8dc5
Merge pull request #24 from PanQiWei/speedup_quantization
Offloading and Multiple devices quantization/inference
2023-04-28 18:50:12 +08:00
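
A hedged usage sketch of the offloading and multi-device loading this PR enables; the device_map and max_memory arguments follow the accelerate-style conventions referenced by these commits, and the values are placeholders:

from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",
    device_map="auto",                        # or an explicit {module_name: device} mapping
    max_memory={0: "10GiB", "cpu": "30GiB"},  # cap GPU usage and offload the rest to CPU
)
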
PanQiWei
5055d785b6 update README.md 2023-04-28 18:44:06 +08:00
PanQiWei
3f761568a8 update example code 2023-04-28 18:26:16 +08:00
PanQiWei
bdb713b5a3 add batch_size to model.quant() api 2023-04-28 18:26:07 +08:00
PanQiWei
41564a48db make data_utils.py as global utils 2023-04-28 18:08:58 +08:00
PanQiWei
e189d91004 update example code 2023-04-28 17:56:52 +08:00
PanQiWei
3dfc87bec3 return module in .to function 2023-04-28 17:20:46 +08:00
PanQiWei
789f821e6c update example code 2023-04-28 17:17:38 +08:00
PanQiWei
a69a73a22c fix device mismatch when directly using model to inference after quantization 2023-04-28 16:41:46 +08:00
PanQiWei
892eeb40e0 update example code 2023-04-28 16:23:02 +08:00
潘其威(William)
092bd502d0
Merge pull request #29 from PanQiWei/fix-bug-speedup_quant
Fix bug speedup quant and support gpt2
2023-04-28 12:02:55 +08:00
qwopqwop200
c4f82ac628
Merge pull request #30 from PanQiWei/add-gpt2
Add gpt2
2023-04-28 09:19:20 +09:00
qwopqwop200
329a64ed40
support conv1d,conv2d 2023-04-28 09:15:42 +09:00
qwopqwop200
bb9afe8b61
support conv1d,conv2d 2023-04-28 09:15:13 +09:00
qwopqwop200
c1b7c7647d
support conv1d 2023-04-28 09:14:44 +09:00
qwopqwop200
ac41f68532
add gpt2 2023-04-28 09:14:05 +09:00
qwopqwop200
dad249990c
add gpt2 2023-04-28 09:13:22 +09:00
qwopqwop200
435eebee4b
support conv1d,conv2d 2023-04-28 09:13:00 +09:00
qwopqwop200
cc0f71a568
add gpt2 2023-04-28 09:11:50 +09:00
qwopqwop200
3f90a22632
fix bug 2023-04-28 08:26:58 +09:00
qwopqwop200
9c38393e31
fix bug about wf meta device 2023-04-28 08:26:11 +09:00
PanQiWei
d0cd5af5d3 make code more robust 2023-04-28 01:29:12 +08:00
PanQiWei
51d2e53130 add support to cpu offloading and multi gpus inference on quantized model 2023-04-28 00:53:57 +08:00
PanQiWei
9c44365836 Merge remote-tracking branch 'origin/speedup_quantization' into speedup_quantization 2023-04-28 00:31:03 +08:00
PanQiWei
48737db0ff update setup.py 2023-04-28 00:30:51 +08:00
PanQiWei
3aac76e71a align 'from_pretrained' api 2023-04-28 00:30:05 +08:00
PanQiWei
b14dca9207 disk offload assertion 2023-04-27 21:31:53 +08:00
PanQiWei
7a3397e7ba add cpu offload when doing quantization 2023-04-27 21:25:24 +08:00
PanQiWei
ac3f7054e0 bug fix 2023-04-27 19:33:25 +08:00
PanQiWei
498de923f2 support multi gpus quantization 2023-04-27 18:48:43 +08:00
潘其威(William)
a21205f609
Merge pull request #23 from PanQiWei/add-option
add option
2023-04-27 16:51:33 +08:00
qwopqwop200
8b6ee04aee
add option 2023-04-27 17:29:36 +09:00
PanQiWei
211ffe94f9 Merge branch 'main' into speedup_quantization 2023-04-27 12:15:05 +08:00
PanQiWei
c4ad7b7630 update README.md 2023-04-27 11:49:23 +08:00
PanQiWei
c9bb427546 align 'from_pretrained' api 2023-04-27 02:29:32 +08:00
PanQiWei
a2abff983e support dispatch layers to different devices when loading pretrained model before quantization 2023-04-27 02:24:08 +08:00
PanQiWei
950f203260 add 'n_positions' to sequence length search list 2023-04-27 01:09:10 +08:00
PanQiWei
8cc814bd81 fix installation failed on cpu only device 2023-04-26 21:26:47 +08:00
PanQiWei
893c3264cb make layer ignorance more robust 2023-04-26 19:35:19 +08:00
PanQiWei
f7fc3ab67b set version to v0.1.0-dev 2023-04-26 18:04:49 +08:00
PanQiWei
e0e0eb1136 set version to v0.1.0-dev0 2023-04-26 18:04:13 +08:00
118 changed files with 20326 additions and 1526 deletions

33
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@@ -0,0 +1,33 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Hardware details**
Information about CPU and GPU, such as RAM, number, etc.
**Software version**
Version of relevant software such as operation system, cuda toolkit, python, auto-gptq, pytorch, transformers, accelerate, etc.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.

10
.github/ISSUE_TEMPLATE/custom.md vendored Normal file

@@ -0,0 +1,10 @@
---
name: Custom issue template
about: Describe this issue template's purpose here.
title: ''
labels: ''
assignees: ''
---

20
.github/ISSUE_TEMPLATE/feature_request.md vendored Normal file

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[FEATURE]"
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

69
.github/workflows/build_wheels_cuda.yml vendored Normal file

@@ -0,0 +1,69 @@
name: Build AutoGPTQ Wheels with CUDA
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA ${{ matrix.cuda }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04, windows-latest]
pyver: ["3.8", "3.9", "3.10", "3.11"]
cuda: ["11.7", "11.8"]
defaults:
run:
shell: pwsh
env:
CUDA_VERSION: ${{ matrix.cuda }}
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.pyver }}
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.pyver }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Install Dependencies
run: |
conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
python -m pip install --upgrade build setuptools wheel ninja
- name: Build Wheel
run: |
$env:CUDA_PATH = $env:CONDA_PREFIX
$env:CUDA_HOME = $env:CONDA_PREFIX
if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}
# TODO: remove this
if (!$IsLinux) {$env:INCLUDE_EXLLAMA_KERNELS = 0}
$env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }
python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
if: runner.os == 'Linux'
with:
name: 'linux-cuda-wheels'
path: ./dist/*.whl
- uses: actions/upload-artifact@v3
if: runner.os == 'Windows'
with:
name: 'windows-cuda-wheels'
path: ./dist/*.whl

74
.github/workflows/build_wheels_pypi.yml vendored Normal file

@@ -0,0 +1,74 @@
name: Build AutoGPTQ Wheels for PyPI with CUDA
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and CUDA 11.7
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04, windows-latest]
pyver: ["3.8", "3.9", "3.10", "3.11"]
defaults:
run:
shell: pwsh
env:
CUDA_VERSION: "11.7"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.pyver }}
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.pyver }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Install Dependencies
run: |
conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia
python -m pip install --upgrade build setuptools wheel ninja
- name: Build Wheel
run: |
$env:CUDA_PATH = $env:CONDA_PREFIX
$env:CUDA_HOME = $env:CONDA_PREFIX
if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH}
$env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
if ([decimal]$env:CUDA_VERSION -ge 11.8) { $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' }
$env:PYPI_RELEASE = "1"
echo "CUDA_PATH:"
echo $env:CUDA_PATH
echo "PYPI_RELEASE:"
echo $env:PYPI_RELEASE
python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
if: runner.os == 'Linux'
with:
name: 'linux-cuda-wheels'
path: ./dist/*.whl
- uses: actions/upload-artifact@v3
if: runner.os == 'Windows'
with:
name: 'windows-cuda-wheels'
path: ./dist/*.whl

103
.github/workflows/build_wheels_rocm.yml vendored Normal file
View file

@ -0,0 +1,103 @@
name: Build AutoGPTQ Wheels with ROCm
on: workflow_dispatch
jobs:
build_wheels:
if: ${{ github.repository_owner == 'PanQiWei' }}
strategy:
matrix:
os: [ubuntu-20.04]
python: ["3.8", "3.9", "3.10", "3.11"]
rocm: ["5.4.2"] # , "5.5", "5.6"]
name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.python }} and RoCm ${{ matrix.rocm }}
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v3
- name: Free disk space
run: |
df -h
echo "Removing large packages"
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel
df -h
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get clean
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get autoclean -y >/dev/null 2>&1
df -h
echo "https://github.com/actions/virtual-environments/issues/709"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
df -h
echo "remove big /usr/local"
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
df -h
sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
sudo rm -rf /usr/share/swift > /dev/null 2>&1
df -h
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python }}
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.python }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Set up environment
run: |
echo "Using python:"
python --version
which python
if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
elif [[ "${{ matrix.rocm }}" == "5.5" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.5.50500-1_all.deb
else
export ROCM_DL_FILE=amdgpu-install_5.6.50600-1_all.deb
fi
curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
sudo dpkg -i $ROCM_DL_FILE
sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev
python -m pip install --upgrade build setuptools wheel ninja
python -m pip install torch --index-url https://download.pytorch.org/whl/rocm${{ matrix.rocm }}
- name: Build wheels
run: |
echo "Using python for build:"
python --version
which python
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v3
with:
name: 'linux-rocm-wheels'
path: ./dist/*.whl

160
.gitignore vendored Normal file
View file

@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

244
README.md
View file

@ -1,90 +1,130 @@
# AutoGPTQ
An easy-to-use model quantization package with user-friendly apis, based on GPTQ algorithm.
<h1 align="center">AutoGPTQ</h1>
<p align="center">An easy-to-use LLMs quantization package with user-friendly apis, based on GPTQ algorithm.</p>
<p align="center">
<a href="https://github.com/PanQiWei/AutoGPTQ/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/PanQiWei/AutoGPTQ.svg">
</a>
<a href="https://pypi.org/project/auto-gptq/">
<img alt="PyPI - Downloads" src="https://img.shields.io/pypi/dd/auto-gptq">
</a>
</p>
<h4 align="center">
<p>
<b>English</b> |
<a href="https://github.com/PanQiWei/AutoGPTQ/blob/main/README_zh.md">中文</a>
</p>
</h4>
## The path to v1.0.0
Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been huge in terms of my career plans. Not long ago, I officially bid farewell to the startup team that I joined for two years after graduation. I'm very grateful to the leaders and colleagues of the team for their trust and guidance, which enabled me to grow rapidly in two years; at the same time, I'm also really grateful to the team for allowing me to use the internal A100 GPU server cluster free of charge since the start of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, it can no longer be used in the future, so **new hardware sponsorship would mean a lot to me!**) In the past two years, I served as an AI engineer in this team, responsible for the architecture design and development of an LLM-based dialogue system. We successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), which is **an LLM-native AI agent platform, where users can use multiple AI agents to build a highly automated team, let them interact with each other in the workflow, and collaborate to complete complex projects efficiently.**
Getting back to the topic, I'm very excited to see that in the past few months, research on optimizing the inference performance of LLMs has made tremendous progress. Now we can efficiently run LLM inference not only on high-end GPUs, but also on CPUs and even edge devices. A series of technological advancements makes me eager to contribute more to the open source community. Therefore, I will first spend about four weeks gradually updating AutoGPTQ to the official v1.0.0 release. During this period, 2~3 minor versions will also be released so that users can try out performance optimizations and new features in a timely manner. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will be able to serve as an extendable and flexible quantization backend that supports all GPTQ-like methods and automatically quantizes LLMs written in PyTorch**. I detailed the development plan in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/348), feel free to drop in there for discussion and to give your suggestions!
## News or Update
- 2023-04-26 - (Update) - Using `triton` to speed up inference is now supported.
- 2023-04-25 - (News&Update) - [MOSS](https://github.com/OpenLMLab/MOSS) is an open-source tool-augmented conversational language model from Fudan University, quantization is now supported in AutoGPTQ.
- 2023-04-23 - (Update) - Support evaluation on multiple (down-stream) tasks such as: language-modeling, text-classification, text-summarization.
- 2023-04-22 - (News) - qwopqwop200's [AutoGPTQ-triton](https://github.com/qwopqwop200/AutoGPTQ-triton) provides faster speed to integrate with quantized model, for everyone who can access to triton, try and enjoy yourself!
- 2023-04-20 - (News) - AutoGPTQ is automatically compatible with Stability-AI's newly released `gpt_neox` type model family [StableLM](https://github.com/Stability-AI/StableLM).
- 2023-04-16 - (Update) - Support quantization and inference for `bloom`, `gpt_neox`, `gptj`, `llama` and `opt`.
- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so now running and training GPTQ models can be more available to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and it's resources for more details!
- 2023-08-21 - (News) - Team of Qwen officially released 4bit quantized version of Qwen-7B based on `auto-gptq`, and provided [a detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel to have at least 1.3x speed up for int4 quantized models when doing inference.
- 2023-08-04 - (Update) - Support RoCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to use gptq quantized model to train adapters, support LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support download/upload quantized model from/to 🤗 Hub.
*For more histories please turn to [here](docs/NEWS_OR_UPDATE.md)*
## Performance Comparison
### Inference Speed
> The result is generated using [this script](examples/benchmark/generation_speed.py), batch size of input is 1, decode strategy is beam search and the model is forced to generate 512 tokens, speed metric is tokens/s (the larger, the better).
>
> The quantized model is loaded using the setup that can gain the fastest inference speed.
| model | GPU | num_beams | fp16 | gptq-int4 |
|---------------|---------------|-----------|-------|-----------|
| llama-7b | 1xA100-40G | 1 | 18.87 | 25.53 |
| llama-7b | 1xA100-40G | 4 | 68.79 | 91.30 |
| moss-moon 16b | 1xA100-40G | 1 | 12.48 | 15.25 |
| moss-moon 16b | 1xA100-40G | 4 | OOM | 42.67 |
| moss-moon 16b | 2xA100-40G | 1 | 06.83 | 06.78 |
| moss-moon 16b | 2xA100-40G | 4 | 13.10 | 10.80 |
| gpt-j 6b | 1xRTX3060-12G | 1 | OOM | 29.55 |
| gpt-j 6b | 1xRTX3060-12G | 4 | OOM | 47.36 |
### Perplexity
For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#result) and [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#gptq-vs-bitsandbytes)
## Installation
### Quick Installation
You can install the latest stable release of AutoGPTQ from pip:
```shell
pip install auto-gptq
```
By default, pytorch extensions will be installed when `torch` is already in your virtual environment. If you don't want to use CUDA extensions, use:
```shell
BUILD_CUDA_EXT=0 pip install auto-gptq
```
If you want to try LLaMa but your `transformers` version does not yet meet the newest one that supports it, use:
```shell
pip install auto-gptq[llama]
```
To integrate with `triton`, use:
> warning: 3-bit quantization is not supported when using triton
You can install the latest stable release of AutoGPTQ from pip with pre-built wheels compatible with PyTorch 2.0.1:
```shell
pip install auto-gptq[triton]
```
* For CUDA 11.7: `pip install auto-gptq`
* For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`
**Warning:** These wheels are not expected to work on PyTorch nightly. Please install AutoGPTQ from source when using PyTorch nightly.
AutoGPTQ can be installed with the Triton dependency via `pip install auto-gptq[triton]` in order to use the Triton backend (currently Linux only; 3-bit quantization is not supported).
### Install from source
Clone the source code:
```shell
git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
```
Then, install from source:
```shell
pip install .
pip install -v .
```
Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch extension building.
You can set `BUILD_CUDA_EXT=0` to disable pytorch extension building, but this is **strongly discouraged** as AutoGPTQ then falls back on a slow python implementation.
Use `.[llama]` if you want to try LLaMa model.
To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. The compilation can be sped up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
Use `.[triton]` if you want to integrate with triton and it's available on your operating system.
```
ROCM_VERSION=5.6 pip install -v .
```
For RoCm systems, the packages `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev` are required to build.
## Supported Models
Currently, `auto_gptq` supports: `bloom`, `gpt_neox`, `gptj`, `llama`, `moss` and `opt`; more CausalLMs will come soon!
## Quick Tour
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
### Quantization and Inference
> warning: this is just a showcase of the usage of basic apis in AutoGPTQ, which uses only one sample to quantize a very small model; the quality of a model quantized with so few samples may not be good.
## Usage
### Basic
> warning: this is just a showcase of the usage of basic apis in AutoGPTQ, which uses only one sample to quantize a very small model, thus it may not perform as well as expected on LLMs.
Below is an example for the simplest use of auto_gptq:
Below is an example for the simplest use of `auto_gptq` to quantize a model and inference after quantization:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
example = tokenizer(
"auto_gptq is a useful tool that can automatically compress model into 4-bit or even higher rate by using GPTQ algorithm.",
return_tensors="pt"
)
examples = [
tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
)
]
quantize_config = BaseQuantizeConfig(
bits=4, # quantize model to 4-bit
group_size=128, # it is recommended to set the value to 128
desc_act=False, # setting to False can significantly speed up inference but perplexity may be slightly worse
)
# load un-quantized model, the model will always be force loaded into cpu
# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize([example], use_triton=False)
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)
# save quantized model
model.save_quantized(quantized_model_dir)
@ -92,19 +132,41 @@ model.save_quantized(quantized_model_dir)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
# load quantized model, currently only support cpu or single gpu
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, Login first via huggingface-cli login.
# or pass explicit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
# alternatively you can save and push at the same time
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
# download quantized model from Hugging Face Hub and load to the first GPU
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to("cuda:0"))[0]))
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
# or you can also use pipeline
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipeline("auto_gptq is")[0]["generated_text"])
print(pipeline("auto-gptq is")[0]["generated_text"])
```
For more advanced features of model quantization, please refer to [this script](examples/quantization/quant_with_alpaca.py)
### Customize Model
Below is an example to extend `auto_gptq` to support the `OPT` model; as you will see, it's very easy:
<details>
<summary>Below is an example to extend `auto_gptq` to support the `OPT` model; as you will see, it's very easy:</summary>
```python
from auto_gptq.modeling import BaseGPTQForCausalLM
@ -118,8 +180,8 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
"model.decoder.project_in", "model.decoder.final_layer_norm"
]
# chained attribute names of linear layers in transformer layer module
# normally, there are four sub lists, for each one the modules in it can be seen as one operation,
# and the order should be the order when they are truly executed, in this case (and usually in most cases),
# normally, there are four sub lists, for each one the modules in it can be seen as one operation,
# and the order should be the order when they are truly executed, in this case (and usually in most cases),
# they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
@ -127,21 +189,20 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
["fc1"],
["fc2"]
]
@staticmethod
# overriding this method may not be necessary for most other models
def _resize_attention_mask(attention_mask):
attention_mask = [each.unsqueeze(1) for each in attention_mask]
return attention_mask
```
After this, you can use `OPTGPTQForCausalLM.from_pretrained` and other functions
After this, you can use `OPTGPTQForCausalLM.from_pretrained` and other methods as shown in Basic.
</details>
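For reference, here is a minimal sketch of using the custom class defined above (the output directory is just a placeholder; `examples` and the quantize config are prepared exactly as in the Basic section):
```python
from auto_gptq import BaseQuantizeConfig

# the custom class is used just like AutoGPTQForCausalLM
model = OPTGPTQForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantize_config=BaseQuantizeConfig(bits=4, group_size=128, desc_act=False),
)
model.quantize(examples)  # `examples` as built in the Basic section
model.save_quantized("opt-125m-4bit-custom")  # placeholder output directory
```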
### Evaluation on Downstream Tasks
One can use tasks defined in `auto_gptq.eval_tasks` to evaluate a model's performance on specific downstream tasks before and after quantization.
You can use tasks defined in `auto_gptq.eval_tasks` to evaluate a model's performance on specific downstream tasks before and after quantization.
The predefined tasks support all causal-language-models implemented in [Hugging Face transformers](https://github.com/huggingface/transformers) and in this project.
The predefined tasks support all causal-language-models implemented in [🤗 transformers](https://github.com/huggingface/transformers) and in this project.
<details>
<summary>Below is an example to evaluate `EleutherAI/gpt-j-6b` on sequence-classification task using `cardiffnlp/tweet_sentiment_multilingual` dataset:</summary>
Below is an example to evaluate `EleutherAI/gpt-j-6b` on sequence-classification task using `cardiffnlp/tweet_sentiment_multilingual` dataset:
```python
from functools import partial
@ -191,14 +252,14 @@ task = SequenceClassificationTask(
"num_samples": 1000, # how many samples will be sampled to evaluation
"sample_max_len": 1024, # max tokens for each sample
"block_max_len": 2048, # max tokens for each data block
# function to load dataset, one must only accept data_name_or_path as input
# function to load dataset, one must only accept data_name_or_path as input
# and return datasets.Dataset
"load_fn": partial(datasets.load_dataset, name="english"),
# function to preprocess dataset, which is used for datasets.Dataset.map,
"load_fn": partial(datasets.load_dataset, name="english"),
# function to preprocess dataset, which is used for datasets.Dataset.map,
# must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name]
"preprocess_fn": ds_refactor_fn,
"preprocess_fn": ds_refactor_fn,
# truncate label when sample's length exceeds sample_max_len
"truncate_prompt": False
"truncate_prompt": False
}
)
@ -217,13 +278,46 @@ print(
)
```
### More Examples
For more examples, please turn to [examples](examples/README.md)
</details>
## Side Notes
### VRAM
Currently, I put everything (data, model, etc.) on CPU until it is required to be used or executed on GPU (and it will move back to CPU once the execution is finished). Though I haven't run any benchmark to date, the maximum VRAM usage for GPTJ is about 6GB, which may be considered as a reference.
## Learn More
[tutorials](docs/tutorial) provide step-by-step guidance to integrate `auto_gptq` with your own project and some best practice principles.
[examples](examples/README.md) provide plenty of example scripts to use `auto_gptq` in different ways.
## Supported Models
> you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`.
>
> for example, the model_type of `WizardLM`, `vicuna` and `gpt4all` is `llama`, hence they are all supported by `auto_gptq` (see the sketch after the table below).
| model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
|------------------------------------|--------------|-----------|-----------|---------------|-------------------------------------------------------------------------------------------------|
| bloom | ✅ | ✅ | ✅ | ✅ | |
| gpt2 | ✅ | ✅ | ✅ | ✅ | |
| gpt_neox | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| gptj | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| llama | ✅ | ✅ | ✅ | ✅ | ✅ |
| moss | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| opt | ✅ | ✅ | ✅ | ✅ | |
| gpt_bigcode | ✅ | ✅ | ✅ | ✅ | |
| codegen | ✅ | ✅ | ✅ | ✅ | |
| falcon(RefinedWebModel/RefinedWeb) | ✅ | ✅ | ✅ | ✅ | |
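As a minimal sketch of the `model_type` check described above (the model id is only an example):
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5")  # example model id
print(config.model_type)  # -> "llama", which appears in the table, so auto_gptq supports it
```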
## Supported Evaluation Tasks
Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
## Running tests
Tests can be run with:
```
pytest tests/ -s
```
## Acknowledgement
- Special thanks to **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing the **GPTQ** algorithm and open-sourcing the [code](https://github.com/IST-DASLab/gptq).
- Special thanks to **qwopqwop200**; the quantization-related code in this project is mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).
[![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)

334
README_zh.md Normal file
View file

@ -0,0 +1,334 @@
<h1 align="center">AutoGPTQ</h1>
<p align="center">一个基于 GPTQ 算法,简单易用且拥有用户友好型接口的大语言模型量化工具包。</p>
<p align="center">
<a href="https://github.com/PanQiWei/AutoGPTQ/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/PanQiWei/AutoGPTQ.svg">
</a>
<a href="https://pypi.org/project/auto-gptq/">
<img alt="PyPI - Downloads" src="https://img.shields.io/pypi/dd/auto-gptq">
</a>
</p>
<h4 align="center">
<p>
<a href="https://github.com/PanQiWei/AutoGPTQ/blob/main/README.md">English</a> |
<b>中文</b>
</p>
</h4>
## 通向 v1.0.0 之路
嗨,社区的伙伴们,好久不见!很抱歉这段时间由于个人原因,我没能以较高的频率来更新这个项目。过去几周对我的职业生涯规划而言意义重大。在不久前,我正式告别了毕业后便加入两年之久的创业团队,非常感谢团队的领导和同事们给予我的信任与指导,让我能够在两年时间里飞速地成长;同时也十分感激团队允许我自 AutoGPTQ 项目创立以来一直无偿使用内部的 A100 GPU 服务器集群以完成各项实验与性能测评。(当然今后是无法继续使用了,因此**若有新的硬件赞助我将感激不尽**!)过去的两年里,我在这个团队中担任算法工程师的角色,负责基于大语言模型的对话系统架构设计与开发,我们曾成功推出一款名为 gemsouls 的产品,但不幸的是它已经停止运营。而现在,这个团队即将推出一款名为 [modelize](https://www.beta.modelize.ai/) 的新产品,**这是一个大模型原生的 AI 智能体平台,用户可以使用多个 AI 智能体搭建一个高度自动化的团队,让它们在工作流中相互合作,高效完成复杂的项目。**
话归正题,我非常兴奋地看到,在过去几个月的时间里,针对大语言模型推理性能优化的研究取得了巨大的进展,如今我们不仅能够在高端显卡上完成大语言模型的推理,甚至在 CPU 和边缘设备上都可以轻松运行大语言模型。一系列的技术进步,让我同样迫不及待地在开源社区上做出更多的贡献,因此,首先,我将用约四周的时间将 AutoGPTQ 迭代至 v1.0.0 正式版本,在此期间,也会有 2~3 个小版本发布以让用户能够及时体验性能优化和新特性。在我的愿景里,**到 v1.0.0 版本正式发布时AutoGPTQ 将能够作为一个灵活可拓展的、支持所有 GPTQ-like 方法的量化后端,自动地完成各种基于 Pytorch 编写的大语言模型的量化工作**。我在[这里](https://github.com/PanQiWei/AutoGPTQ/issues/348)详细介绍了开发计划,欢迎移步至此进行讨论并给出你们的建议!
## 新闻或更新
- 2023-08-23 - (新闻) - 🤗 Transformers、optimum 和 peft 完成了对 `auto-gptq` 的集成,现在使用 GPTQ 模型进行推理和训练将变得更容易!阅读 [这篇博客](https://huggingface.co/blog/gptq-integration) 和相关资源以了解更多细节!
- 2023-08-21 - (新闻) - 通义千问团队发布了基于 `auto-gptq` 的 Qwen-7B 4bit 量化版本模型,并提供了[详尽的测评结果](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (更新) - 支持 exllama 的 q4 CUDA 算子使得 int4 量化模型能够获得至少1.3倍的推理速度提升.
- 2023-08-04 - (更新) - 支持 RoCm 使得 AMD GPU 的用户能够使用 auto-gptq 的 CUDA 拓展.
- 2023-07-26 - (更新) - 一个优雅的 [PPL 测评脚本](examples/benchmark/perplexity.py)以获得可以与诸如 `llama.cpp` 等代码库进行公平比较的结果。
- 2023-06-05 - (更新) - 集成 🤗 peft 来使用 gptq 量化过的模型训练适应层,支持 LoRAAdaLoRAAdaptionPrompt 等。
- 2023-05-30 - (更新) - 支持从 🤗 Hub 下载量化好的模型或上次量化好的模型到 🤗 Hub。
*获取更多的历史信息,请转至[这里](docs/NEWS_OR_UPDATE.md)*
## 性能对比
### 推理速度
> 以下结果通过[这个脚本](examples/benchmark/generation_speed.py)生成,文本输入的 batch size 为1解码策略为 beam search 并且强制模型生成512个 token速度的计量单位为 tokens/s越大越好
>
> 量化模型通过能够最大化推理速度的方式加载。
| model | GPU | num_beams | fp16 | gptq-int4 |
|---------------|---------------|-----------|-------|-----------|
| llama-7b | 1xA100-40G | 1 | 18.87 | 25.53 |
| llama-7b | 1xA100-40G | 4 | 68.79 | 91.30 |
| moss-moon 16b | 1xA100-40G | 1 | 12.48 | 15.25 |
| moss-moon 16b | 1xA100-40G | 4 | OOM | 42.67 |
| moss-moon 16b | 2xA100-40G | 1 | 06.83 | 06.78 |
| moss-moon 16b | 2xA100-40G | 4 | 13.10 | 10.80 |
| gpt-j 6b | 1xRTX3060-12G | 1 | OOM | 29.55 |
| gpt-j 6b | 1xRTX3060-12G | 4 | OOM | 47.36 |
### 困惑度PPL
对于困惑度的对比, 你可以参考 [这里](https://github.com/qwopqwop200/GPTQ-for-LLaMa#result) 和 [这里](https://github.com/qwopqwop200/GPTQ-for-LLaMa#gptq-vs-bitsandbytes)
## 安装
### 快速安装
你可以通过 pip 来安装与 PyTorch 2.0.1 相兼容的最新稳定版本的 AutoGPTQ 的预构建轮子文件:
* 对于 CUDA 11.7 `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/`
* 对于 CUDA 11.8 `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
* 对于 RoCm 5.4.2 `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`
**警告:** 预构建的轮子文件不一定在 PyTorch 的 nightly 版本上有效。如果要使用 PyTorch 的 nightly 版本,请从源码安装 AutoGPTQ。
#### 取消 cuda 拓展的安装
默认情况下,在 `torch``cuda` 已经于你的机器上被安装时cuda 拓展将被自动安装,如果你不想要这些拓展的话,采用以下安装命令:
```shell
BUILD_CUDA_EXT=0 pip install auto-gptq
```
同时为确保该拓展——`autogptq_cuda` 不再存在于你的虚拟环境,执行以下命令:
```shell
pip uninstall autogptq_cuda -y
```
#### 支持使用 triton 加速
若想使用 `triton` 加速模型推理,使用以下命令:
> 警告:目前 triton 仅支持 linux 操作系统;当使用 triton 时 3-bit 数值类型的量化将不被支持
```shell
pip install auto-gptq[triton]
```
### 从源码安装
<details>
<summary>点击以查看详情</summary>
克隆源码:
```shell
git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
```
然后,从项目目录安装:
```shell
pip install .
```
正如在快速安装一节,你可以使用 `BUILD_CUDA_EXT=0` 来取消构建 cuda 拓展。
如果你想要使用 triton 加速且其能够被你的操作系统所支持,请使用 `.[triton]`
对应 AMD GPUs为了从源码安装以支持 RoCm请设置 `ROCM_VERSION` 环境变量。同时通过设置 `PYTORCH_ROCM_ARCH` ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)) 可提升编译速度,例如:对于 MI200 系列设备,该变量可设为 `gfx90a`。例子:
```
ROCM_VERSION=5.6 pip install .
```
对于 RoCm 系统,在从源码安装时额外需要提前安装以下包:`rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev`
</details>
## 快速开始
### 量化和推理
> 警告:这里仅是对 AutoGPTQ 中基本接口的用法展示,只使用了一条文本来量化一个特别小的模型,因此其结果的表现可能不如在大模型上执行量化后预期的那样好。
以下展示了使用 `auto_gptq` 进行量化和推理的最简单用法:
```python
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [
tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
)
]
quantize_config = BaseQuantizeConfig(
bits=4, # 将模型量化为 4-bit 数值类型
group_size=128, # 一般推荐将此参数的值设置为 128
desc_act=False, # 设为 False 可以显著提升推理速度,但是 ppl 可能会轻微地变差
)
# 加载未量化的模型,默认情况下,模型总是会被加载到 CPU 内存中
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
# 量化模型, 样本的数据类型应该为 List[Dict],其中字典的键有且仅有 input_ids 和 attention_mask
model.quantize(examples)
# 保存量化好的模型
model.save_quantized(quantized_model_dir)
# 使用 safetensors 保存量化好的模型
model.save_quantized(quantized_model_dir, use_safetensors=True)
# 将量化好的模型直接上传至 Hugging Face Hub
# 当使用 use_auth_token=True 时, 确保你已经首先使用 huggingface-cli login 进行了登录
# 或者可以使用 use_auth_token="hf_xxxxxxx" 来显式地添加账户认证 token
# (取消下面三行代码的注释来使用该功能)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
# 或者你也可以同时将量化好的模型保存到本地并上传至 Hugging Face Hub
# (取消下面三行代码的注释来使用该功能)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)
# 加载量化好的模型到能被识别到的第一块显卡中
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
# 从 Hugging Face Hub 下载量化好的模型并加载到能被识别到的第一块显卡中
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# 使用 model.generate 执行推理
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
# 或者使用 TextGenerationPipeline
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipeline("auto-gptq is")[0]["generated_text"])
```
参考 [此样例脚本](examples/quantization/quant_with_alpaca.py) 以了解进阶的用法。
### 自定义模型
<details>
<summary>以下展示了如何拓展 `auto_gptq` 以支持 `OPT` 模型,如你所见,这非常简单:</summary>
```python
from auto_gptq.modeling import BaseGPTQForCausalLM
class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
# chained attribute name of transformer layer block
layers_block_name = "model.decoder.layers"
# chained attribute names of other nn modules that in the same level as the transformer layer block
outside_layer_modules = [
"model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",
"model.decoder.project_in", "model.decoder.final_layer_norm"
]
# chained attribute names of linear layers in transformer layer module
# normally, there are four sub lists, for each one the modules in it can be seen as one operation,
# and the order should be the order when they are truly executed, in this case (and usually in most cases),
# they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.out_proj"],
["fc1"],
["fc2"]
]
```
然后, 你就可以像在基本用法一节中展示的那样使用 `OPTGPTQForCausalLM.from_pretrained` 和其他方法。
</details>
### 在下游任务上执行评估
你可以使用在 `auto_gptq.eval_tasks` 中定义的任务来评估量化前后的模型在某个特定下游任务上的表现。
这些预定义的模型支持所有在 [🤗 transformers](https://github.com/huggingface/transformers)和本项目中被实现了的 causal-language-models。
<details>
<summary>以下是使用 `cardiffnlp/tweet_sentiment_multilingual` 数据集在序列分类(文本分类)任务上评估 `EleutherAI/gpt-j-6b` 模型的示例:</summary>
```python
from functools import partial
import datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import SequenceClassificationTask
MODEL = "EleutherAI/gpt-j-6b"
DATASET = "cardiffnlp/tweet_sentiment_multilingual"
TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
ID2LABEL = {
0: "negative",
1: "neutral",
2: "positive"
}
LABELS = list(ID2LABEL.values())
def ds_refactor_fn(samples):
text_data = samples["text"]
label_data = samples["label"]
new_samples = {"prompt": [], "label": []}
for text, label in zip(text_data, label_data):
prompt = TEMPLATE.format(labels=LABELS, text=text)
new_samples["prompt"].append(prompt)
new_samples["label"].append(ID2LABEL[label])
return new_samples
# model = AutoModelForCausalLM.from_pretrained(MODEL).eval().half().to("cuda:0")
model = AutoGPTQForCausalLM.from_pretrained(MODEL, BaseQuantizeConfig())
tokenizer = AutoTokenizer.from_pretrained(MODEL)
task = SequenceClassificationTask(
model=model,
tokenizer=tokenizer,
classes=LABELS,
data_name_or_path=DATASET,
prompt_col_name="prompt",
label_col_name="label",
**{
"num_samples": 1000, # how many samples will be sampled to evaluation
"sample_max_len": 1024, # max tokens for each sample
"block_max_len": 2048, # max tokens for each data block
# function to load dataset, one must only accept data_name_or_path as input
# and return datasets.Dataset
"load_fn": partial(datasets.load_dataset, name="english"),
# function to preprocess dataset, which is used for datasets.Dataset.map,
# must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name]
"preprocess_fn": ds_refactor_fn,
# truncate label when sample's length exceed sample_max_len
"truncate_prompt": False
}
)
# note that max_new_tokens will be automatically specified internally based on given classes
print(task.run())
# self-consistency
print(
task.run(
generation_config=GenerationConfig(
num_beams=3,
num_return_sequences=3,
do_sample=True
)
)
)
```
</details>
## 了解更多
[教程](docs/tutorial) 提供了将 `auto_gptq` 集成到你的项目中的手把手指导和最佳实践准则。
[示例](examples/README.md) 提供了大量示例脚本以将 `auto_gptq` 用于不同领域。
## 支持的模型
> 你可以使用 `model.config.model_type` 来对照下表以检查你正在使用的一个模型是否被 `auto_gptq` 所支持。
>
> 比如, `WizardLM``vicuna``gpt4all` 模型的 `model_type` 皆为 `llama` 因此这些模型皆被 `auto_gptq` 所支持。
| model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
|------------------------------------|--------------|-----------|-----------|---------------|-----------------------------------------------------------------------------------|
| bloom | ✅ | ✅ | ✅ | ✅ | |
| gpt2 | ✅ | ✅ | ✅ | ✅ | |
| gpt_neox | ✅ | ✅ | ✅ | ✅ | ✅[要求该分支的 peft](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| gptj | ✅ | ✅ | ✅ | ✅ | ✅[要求该分支的 peft](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| llama | ✅ | ✅ | ✅ | ✅ | ✅ |
| moss | ✅ | ✅ | ✅ | ✅ | ✅[要求该分支的 peft](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
| opt | ✅ | ✅ | ✅ | ✅ | |
| gpt_bigcode | ✅ | ✅ | ✅ | ✅ | |
| codegen | ✅ | ✅ | ✅ | ✅ | |
| falcon(RefinedWebModel/RefinedWeb) | ✅ | ✅ | ✅ | ✅ | |
## 支持的评估任务
目前, `auto_gptq` 支持以下评估任务: `LanguageModelingTask`, `SequenceClassificationTask``TextSummarizationTask`;更多的评估任务即将到来!
## 致谢
- 特别感谢 **Elias Frantar** **Saleh Ashkboos** **Torsten Hoefler****Dan Alistarh** 提出 **GPTQ** 算法并开源[代码](https://github.com/IST-DASLab/gptq)。
- 特别感谢 **qwopqwop200** 本项目中涉及到模型量化的代码主要参考自 [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda)。
[![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)

View file

@ -1,2 +1,5 @@
__version__ = "0.5.0.dev0"
from .modeling import BaseQuantizeConfig
from .modeling import AutoGPTQForCausalLM
from .utils.peft_utils import get_gptq_peft_model
from .utils.exllama_utils import exllama_set_max_input_length
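As an illustrative sketch of the new top-level exports shown above (the model directory is a placeholder; the helper is only relevant for act-order models running on the exllama kernel):
```python
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

model = AutoGPTQForCausalLM.from_quantized("opt-125m-4bit", device="cuda:0", disable_exllama=False)
# assumption: for act-order models on the exllama kernel, the temp_state buffer caps the
# prompt length, and this helper re-allocates it for longer inputs
model = exllama_set_max_input_length(model, max_input_length=4096)
```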

View file

@ -1,4 +1,3 @@
from . import _utils as eval_task_utils
from .language_modeling_task import *
from .sequence_classification_task import *
from .text_summarization_task import *

View file

@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedModel
from ._utils.data_utils import get_dataloader
from ..modeling import BaseGPTQForCausalLM
from ..utils.data_utils import get_dataloader
class BaseTask:

View file

@ -1,7 +1,17 @@
from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
from .auto import *
from .bloom import *
from .gpt2 import *
from .gpt_neox import *
from .gptj import *
from .llama import *
from .moss import *
from .opt import *
from .rw import *
from .gpt_bigcode import *
from .codegen import *
from .baichuan import *
from .internlm import *
from .qwen import *
from .mistral import *
from .mpt import *

File diff suppressed because it is too large

View file

@ -1,13 +1,36 @@
from packaging.version import parse as parse_version
from torch import device
from transformers import __version__ as transformers_version
from ..utils.import_utils import compare_transformers_version
CPU = device("cpu")
CUDA = device("cuda:0")
CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt", "moss"]
if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS = [
"bloom",
"gptj",
"gpt2",
"gpt_neox",
"opt",
"moss",
"gpt_bigcode",
"codegen",
"RefinedWebModel",
"RefinedWeb",
"baichuan",
"internlm",
"qwen",
"mpt",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
if compare_transformers_version("v4.33.0", op="ge"):
SUPPORTED_MODELS.append("falcon")
if compare_transformers_version("v4.34.0", op="ge"):
SUPPORTED_MODELS.append("mistral")
__all__ = ["CPU", "CUDA", "SUPPORTED_MODELS"]
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]

View file

@ -1,37 +1,69 @@
from logging import getLogger
from typing import Union, Optional
import accelerate
import torch
import torch.nn as nn
from transformers import AutoConfig
import transformers
from ._const import SUPPORTED_MODELS, CUDA
from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from ..utils.import_utils import dynamically_import_QuantLinear
logger = getLogger(__name__)
def get_device(obj: Union[torch.Tensor, nn.Module]):
if isinstance(obj, torch.Tensor):
return obj.device
return next(obj.parameters()).device
def move_to_device(obj: Union[torch.Tensor, nn.Module], device: torch.device):
if get_device(obj) != device:
obj = obj.to(device)
return obj
def find_layers(module, layers=None, name=''):
if not layers:
layers = [nn.Conv2d, nn.Linear]
if type(module) in layers:
return {name: module}
layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]
for layer in layers:
if isinstance(module,layer):
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res
def get_module_by_name(model, module_name: str):
def get_module_by_name_prefix(model, module_name: str):
for name, module in model.named_modules():
if name.startswith(module_name):
return module
def make_quant(module, names, bits, groupsize, name='', use_triton=False):
if use_triton:
from ..nn_modules.qlinear_triton import QuantLinear
else:
from ..nn_modules.qlinear import QuantLinear
def get_module_by_name_suffix(model, module_name: str):
for name, module in model.named_modules():
if name.endswith(module_name):
return module
def make_quant(
module,
names,
bits,
group_size,
name='',
use_triton: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
use_qigen: bool = False,
use_cuda_fp16: bool = True,
desc_act: bool = False,
trainable: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen)
if isinstance(module, QuantLinear):
return
@ -39,43 +71,323 @@ def make_quant(module, names, bits, groupsize, name='', use_triton=False):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
ori_layer_device = get_device(getattr(module, attr))
delattr(module, attr)
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
if isinstance(tmp,nn.Linear):
in_features = tmp.in_features
out_features = tmp.out_features
elif isinstance(tmp,nn.Conv2d):
in_features = tmp.in_channels
out_features = tmp.out_channels
elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
in_features = tmp.weight.shape[0]
out_features = tmp.weight.shape[1]
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
new_layer = QuantLinear(
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
)
else:
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
new_layer.device = ori_layer_device
setattr(module, attr, new_layer.to(ori_layer_device))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1, use_triton=use_triton)
make_quant(
child,
names,
bits,
group_size,
name + '.' + name1 if name != '' else name1,
use_triton=use_triton,
use_cuda_fp16=use_cuda_fp16,
desc_act=desc_act,
trainable=trainable,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen
)
def preprocess_checkpoint_qigen(
module,
names,
bits,
group_size,
checkpoint,
name='',
):
try:
import cQIGen as qinfer
except ImportError:
logger.error('cQIGen not installed.')
raise
def pack_model(model, quantizers, bits, group_size, use_triton=False, autotune_warmup: bool = False):
if use_triton:
from ..nn_modules.qlinear_triton import QuantLinear, autotune_warmup_linear
else:
from ..nn_modules.qlinear import QuantLinear
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
if isinstance(module, QuantLinear):
in_features = module.infeatures
out_features = module.outfeatures
zeros = checkpoint[name + '.qzeros']
scales = checkpoint[name + '.scales'].float()
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != out_features:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != out_features:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
del checkpoint[name + '.qzeros']
del checkpoint[name + '.g_idx']
if name + '.bias' in checkpoint:
checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
else:
checkpoint[name + '.bias'] = torch.zeros(out_features)
checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
if bits == 4:
qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
elif bits == 3:
qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
elif bits == 2:
qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
checkpoint[name + '.qweight'] = qweight
return
for name1, child in module.named_children():
preprocess_checkpoint_qigen(
child,
names,
bits,
group_size,
checkpoint,
name + '.' + name1 if name != '' else name1,
)
def pack_model(
model,
quantizers,
bits,
group_size,
use_triton=False,
use_cuda_fp16=True,
desc_act=False,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=False, disable_exllamav2=True)
if force_layer_back_to_cpu:
model.to(CPU)
model.cpu()
logger.info('Packing model...')
layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(model, quantizers, bits, group_size, use_triton=use_triton)
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act, disable_exllama=False, disable_exllamav2=True)
qlayers = find_layers(model, [QuantLinear])
for name in qlayers:
logger.info(name)
quantizers[name], scale, zero, g_idx = quantizers[name]
# so far can only pack layer on CPU
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = layers[name].to(CPU), scale.to(CPU), zero.to(CPU), g_idx.to(CPU)
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
logger.info('Model packed.')
if use_triton and autotune_warmup:
if use_triton and warmup_triton:
logger.warning(
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the hole model."
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
)
autotune_warmup_linear(model.to(CUDA), seqlen=model.seqlen)
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
def check_and_get_model_type(model_dir):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
def check_and_get_model_type(model_dir, trust_remote_code=False):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
return model_type
__all__ = ["find_layers", "get_module_by_name", "make_quant", "pack_model", "check_and_get_model_type"]
def simple_dispatch_model(model, device_map):
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
if "" in device_map:
d = device_map[""]
model = model.to(torch.device(d))
model.hf_device_map = device_map
return model
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
prev_hook = None
for idx, (n, d) in enumerate(cpu_offload_group):
m = get_module_by_name_suffix(model, n)
_, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
if len(cpu_offload_group) > 1:
get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
for n, d in device_map.items():
m = get_module_by_name_suffix(model, n)
if d != "cpu":
d = torch.device(d)
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
add_hook_to_module(m, hook)
accelerate.utils.modeling.retie_parameters(model, tied_params)
model.hf_device_map = device_map
return model
def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
"""
The max_input_length argument is specific to the exllama backend, which requires initializing a temp_state buffer.
"""
device_to_buffers_size = {}
model_uses_exllama = False
for name, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
model_uses_exllama = True
device = submodule.qweight.device
if device not in device_to_buffers_size:
device_to_buffers_size[device] = {
"max_dq_buffer_size": 1,
"max_inner_outer_dim": 1
}
if not use_act_order:
submodule._use_act_order = False
else:
submodule._use_act_order = True
# Disable this heuristic for detecting act_order, but it could be used instead of the config.
"""
if submodule.g_idx is None:
submodule.act_order = False
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
submodule.g_idx = None
submodule.act_order = False
else:
submodule.act_order = True
"""
device_to_buffers_size[device]["max_dq_buffer_size"] = max(device_to_buffers_size[device]["max_dq_buffer_size"], submodule.qweight.numel() * 8)
if use_act_order:
device_to_buffers_size[device]["max_inner_outer_dim"] = max(device_to_buffers_size[device]["max_inner_outer_dim"], submodule.infeatures, submodule.outfeatures)
if model_uses_exllama:
# To be honest this is quite ugly, not proud of this.
from exllama_kernels import prepare_buffers, set_tuning_params
device_to_buffers = {}
if use_act_order:
if max_input_length is None:
max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
else:
max_input_len = max_input_length
else:
if max_input_length is not None:
logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")
max_input_len = 1
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_len, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
# Buffers need to be persistent to avoid any bug.
model.device_to_buffers = device_to_buffers
for device, buffers in model.device_to_buffers.items():
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
# The buffers need to have been initialized first before calling make_q4.
for name, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
submodule.post_init()
## exllamav2
fixed_bytes = {}
model_uses_exllamav2 = False
for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device,0))
if model_uses_exllamav2:
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
device_tensors = {}
for device, scratch_bytes in fixed_bytes.items():
device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
# have persistent buffers, otherwise we will get OOM
model.device_tensors = device_tensors
for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
device = submodule.qweight.device
submodule.post_init(temp_dq = model.device_tensors[device])
torch.cuda.empty_cache()
return model
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size, bits: int):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
for n, m in model.named_modules():
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
m.register_buffer('bias', torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
__all__ = [
"get_device",
"move_to_device",
"find_layers",
"get_module_by_name_prefix",
"get_module_by_name_suffix",
"make_quant",
"preprocess_checkpoint_qigen",
"pack_model",
"autogptq_post_init",
"check_and_get_model_type",
"simple_dispatch_model",
"make_sure_no_tensor_in_meta_device"
]
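For orientation, a hedged sketch of the exllama post-init flow described in the docstring above (in normal use `from_quantized` is expected to trigger this internally; `model` here stands for an already-loaded quantized model):
```python
from auto_gptq.modeling._utils import autogptq_post_init

# assumption: `model` is a quantized model whose layers use the exllama kernel; with
# act-order enabled, max_input_length sizes the temp_state buffer, otherwise it is ignored
model = autogptq_post_init(model, use_act_order=True, max_input_length=2048)
```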

View file

@ -1,20 +1,42 @@
from inspect import signature
from typing import Dict, Optional, Union
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
"bloom": BloomGPTQForCausalLM,
"gpt_neox": GPTNeoXGPTQForCausalLM,
"gptj": GPTJGPTQForCausalLM,
"gpt2": GPT2GPTQForCausalLM,
"llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM,
"moss": MOSSGPTQForCausalLM
"moss": MOSSGPTQForCausalLM,
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
"codegen": CodeGenGPTQForCausalLM,
"RefinedWebModel": RWGPTQForCausalLM,
"RefinedWeb": RWGPTQForCausalLM,
"falcon": RWGPTQForCausalLM,
"baichuan": BaiChuanGPTQForCausalLM,
"internlm": InternLMGPTQForCausalLM,
"qwen": QwenGPTQForCausalLM,
"mistral": MistralGPTQForCausalLM,
"mpt": MPTGPTQForCausalLM,
}
@ -31,31 +53,83 @@ class AutoGPTQForCausalLM:
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
bf16: bool = False,
max_memory: Optional[dict] = None,
trust_remote_code: bool = False,
**model_init_kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(pretrained_model_name_or_path)
model_type = check_and_get_model_type(
pretrained_model_name_or_path, trust_remote_code
)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
quantize_config=quantize_config,
bf16=bf16,
max_memory=max_memory,
trust_remote_code=trust_remote_code,
**model_init_kwargs
)
@classmethod
def from_quantized(
cls,
save_dir: str,
device: str = "cpu",
model_name_or_path: Optional[str],
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = False,
use_triton: bool = False
trust_remote_code: bool = False,
warmup_triton: bool = False,
trainable: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(save_dir)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
save_dir=save_dir,
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
# A static list of kwargs needed for huggingface_hub
huggingface_kwargs = [
"cache_dir",
"force_download",
"proxies",
"resume_download",
"local_files_only",
"use_auth_token",
"revision",
"subfolder",
"_raise_exceptions_for_missing_entries",
"_commit_hash"
]
# TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
keywords = {
key: kwargs[key]
for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
if key in kwargs
}
return quant_func(
model_name_or_path=model_name_or_path,
device_map=device_map,
max_memory=max_memory,
device=device,
low_cpu_mem_usage=low_cpu_mem_usage,
use_triton=use_triton,
inject_fused_attention=inject_fused_attention,
inject_fused_mlp=inject_fused_mlp,
use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
use_triton=use_triton
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
trainable=trainable,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
**keywords
)
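A minimal standalone sketch of the signature-based kwargs filtering used above (the helper name filter_kwargs is made up for illustration):

from inspect import signature

def filter_kwargs(func, kwargs, extra_allowed=()):
    # Keep only the kwargs that `func` accepts, plus an explicit allow-list
    # (here that would be the huggingface_hub download arguments).
    allowed = set(signature(func).parameters) | set(extra_allowed)
    return {key: value for key, value in kwargs.items() if key in allowed}

# filter_kwargs(quant_func, {"revision": "main", "unexpected": 1}, huggingface_kwargs)
# would keep "revision" and drop "unexpected".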

View file

@ -0,0 +1,16 @@
from ._base import *
class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "DecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
["self_attn.W_pack"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"]
]
__all__ = ["BaiChuanGPTQForCausalLM"]

View file

@ -2,6 +2,7 @@ from ._base import *
class BloomGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "BloomBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
inside_layer_modules = [

View file

@ -0,0 +1,16 @@
from ._base import *
class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "CodeGenBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [
["attn.qkv_proj"],
["attn.out_proj"],
["mlp.fc_in"],
["mlp.fc_out"]
]
__all__ = ["CodeGenGPTQForCausalLM"]

View file

@ -0,0 +1,16 @@
from ._base import *
class GPT2GPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPT2Block"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f"]
inside_layer_modules = [
["attn.c_attn"],
["attn.c_proj"],
["mlp.c_fc"],
["mlp.c_proj"]
]
__all__ = ["GPT2GPTQForCausalLM"]

View file

@ -0,0 +1,17 @@
from auto_gptq.modeling import BaseGPTQForCausalLM
class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPTBigCodeBlock"
layers_block_name = "transformer.h"
outside_layer_modules = [
"transformer.wpe", "transformer.wte", "transformer.ln_f"
]
inside_layer_modules = [
["attn.c_attn"],
["attn.c_proj"],
["mlp.c_fc"],
["mlp.c_proj"]
]
__all__ = ["GPTBigCodeGPTQForCausalLM"]

View file

@ -2,6 +2,7 @@ from ._base import *
class GPTNeoXGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPTNeoXLayer"
layers_block_name = "gpt_neox.layers"
outside_layer_modules = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"]
inside_layer_modules = [

View file

@ -1,7 +1,9 @@
from ._base import *
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPTJBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [
@ -11,5 +13,7 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"]
]
fused_attn_module_type = FusedGPTJAttentionForQuantizedModel
__all__ = ["GPTJGPTQForCausalLM"]

View file

@ -0,0 +1,16 @@
from ._base import *
class InternLMGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "InternLMDecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"],
]
__all__ = ["InternLMGPTQForCausalLM"]

View file

@ -1,7 +1,20 @@
from logging import getLogger
from ._base import *
from ..utils.import_utils import compare_transformers_version
if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None
logger = getLogger(__name__)
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "LlamaDecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
@ -11,10 +24,8 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"]
]
@staticmethod
def _resize_attention_mask(attention_mask):
attention_mask = [each.unsqueeze(1) for each in attention_mask]
return attention_mask
fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel
__all__ = ["LlamaGPTQForCausalLM"]

View file

@ -0,0 +1,16 @@
from ._base import *
class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "MistralDecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"],
]
__all__ = ["MistralGPTQForCausalLM"]

View file

@ -2,11 +2,15 @@ from ._base import *
class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "MossBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.drop", "transformer.ln_f"]
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [
["attn.qkv_proj"],
["attn.out_proj"],
["mlp.fc_in"],
["mlp.fc_out"]
]
__all__ = ["MOSSGPTQForCausalLM"]

auto_gptq/modeling/mpt.py (new file, 18 lines)
View file

@ -0,0 +1,18 @@
from auto_gptq.modeling import BaseGPTQForCausalLM
class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "MPTBlock"
layers_block_name = "transformer.blocks"
outside_layer_modules = [
"transformer.wte", "transformer.norm_f"
]
inside_layer_modules = [
["attn.Wqkv"],
["attn.out_proj"],
["ffn.up_proj"],
["ffn.down_proj"]
]
__all__ = ["MPTGPTQForCausalLM"]

View file

@ -2,6 +2,7 @@ from ._base import *
class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "OPTDecoderLayer"
layers_block_name = "model.decoder.layers"
outside_layer_modules = [
"model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",
@ -14,10 +15,5 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
["fc2"]
]
@staticmethod
def _resize_attention_mask(attention_mask):
attention_mask = [each.unsqueeze(1) for each in attention_mask]
return attention_mask
__all__ = ["OPTGPTQForCausalLM"]

View file

@ -0,0 +1,16 @@
from ._base import *
class QwenGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "QWenBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f", "transformer.visual"]
inside_layer_modules = [
["attn.c_attn"],
["attn.c_proj"],
["mlp.w1", "mlp.w2"],
["mlp.c_proj"]
]
__all__ = ["QwenGPTQForCausalLM"]

auto_gptq/modeling/rw.py (new file, 15 lines)
View file

@ -0,0 +1,15 @@
from ._base import *
class RWGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "DecoderLayer"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
inside_layer_modules = [
["self_attention.query_key_value"],
["self_attention.dense"],
["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"]
]
__all__ = ["RWGPTQForCausalLM"]

View file

@ -0,0 +1,42 @@
from abc import abstractmethod
from logging import getLogger
import torch.nn as nn
from .triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
class FusedBaseModule(nn.Module, TritonModuleMixin):
@classmethod
@abstractmethod
def inject_to_model(cls, *args, **kwargs):
raise NotImplementedError()
class FusedBaseAttentionModule(FusedBaseModule):
@classmethod
@abstractmethod
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
**kwargs
):
raise NotImplementedError()
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
pass
class FusedBaseMLPModule(FusedBaseModule):
@classmethod
@abstractmethod
def inject_to_model(cls, model, use_triton=False, **kwargs):
raise NotImplementedError()
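A minimal sketch of how a model-specific fused module plugs into these bases (hypothetical class; the concrete GPT-J and Llama implementations below follow this shape):

class FusedFooAttentionForQuantizedModel(FusedBaseAttentionModule):
    @classmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        for name, module in model.named_modules():
            # locate the model-specific attention module and replace it in place
            # with an instance of this fused class (see the implementations below)
            pass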

View file

@ -0,0 +1,303 @@
from typing import *
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.gptj.modeling_gptj import GPTJAttention
from ._fused_base import FusedBaseAttentionModule
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)
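A small shape-check sketch for the rotary helpers above (hypothetical sizes, purely illustrative):

import torch

k = torch.randn(1, 8, 16, 64)                            # (batch, seq_len, num_heads, rotary_dim)
sincos = fixed_pos_embedding(k, seq_dim=1, seq_len=8)    # sin and cos each of shape (8, 32)
assert apply_rotary_pos_emb(k, sincos).shape == k.shape  # the rotation preserves the input shape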
class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
def __init__(self, config):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.attn_dropout_p = config.attn_pdrop
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
def _split_heads(self, qkv):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = qkv.size()[:-1] + (3, self.num_attention_heads, self.head_dim)
qkv = qkv.view(new_shape) # (batch, seq_length, 3, head, head_features)
query = qkv[:, :, 0]
key = qkv[:, :, 1]
value = qkv[:, :, 2]
return query, key, value
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query, key, value = self._split_heads(self.qkv_proj(hidden_states))
seq_len = key.shape[1]
offset = 0
if layer_past is not None:
offset = layer_past[0].shape[-2]
seq_len += offset
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
key = apply_rotary_pos_emb(key, sincos, offset=offset)
query = apply_rotary_pos_emb(query, sincos, offset=offset)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
is_causal = layer_past is None
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
present = (key, value)
else:
present = None
# compute self-attention: V x Softmax(QK^T)
if compare_pytorch_version("v2.0.0", op="ge"):
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None if is_causal else attention_mask,
dropout_p=self.attn_dropout_p,
is_causal=is_causal
)
attn_weights = None
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
@classmethod
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs
):
config = model.config
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
for name, m in model.named_modules():
if not isinstance(m, GPTJAttention):
continue
attn = cls(config).to(device=next(m.buffers()).device)
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# See fused_llama_attn.py comment
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
True if q_proj.bias is not None else False,
)
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
qkv_proj.qweight = qweights
qkv_proj.qzeros = qzeros
qkv_proj.scales = scales
qkv_proj.g_idx = g_idx
qkv_proj.bias = bias
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
attn.qkv_proj = qkv_proj
attn.out_proj = m.out_proj
setattr(parent, child_name, attn)
del m
__all__ = ["FusedGPTJAttentionForQuantizedModel"]

View file

@ -0,0 +1,203 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from ._fused_base import FusedBaseAttentionModule
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
qkv_proj,
o_proj,
rotary_emb,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
if self.head_dim * num_heads != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.rotary_emb = rotary_emb
def _shape(self, tensor, seq_len, bsz):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states,
past_key_value=None,
attention_mask=None,
position_ids=None,
output_attentions=False,
use_cache=False,
**kwargs
):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
is_causal = past_key_value is None
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if use_cache:
# Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
past_key_value = (key_states, value_states) if use_cache else None
if compare_pytorch_version("v2.0.0", op="ge"):
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None if is_causal else attention_mask,
is_causal=is_causal
)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
@classmethod
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs
):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
for name, m in model.named_modules():
if not isinstance(m, LlamaAttention):
continue
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# TODO: support it. The issue lies maybe in the line:
# int groups = qzeros.size(0);
# in exllama_ext.cpp
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
True if q_proj.bias is not None else False,
)
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.g_idx = g_idx
qkv_layer.bias = bias
attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
setattr(parent, child_name, attn)
__all__ = ["FusedLlamaAttentionForQuantizedModel"]

View file

@ -0,0 +1,330 @@
import math
from logging import getLogger
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from ._fused_base import FusedBaseMLPModule
from ..utils.import_utils import TRITON_AVAILABLE
logger = getLogger(__name__)
if TRITON_AVAILABLE:
import triton
import triton.language as tl
from .triton_utils import custom_autotune
from .triton_utils.kernels import silu
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 256,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
), # 3090
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 16,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
), # 3090
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=4
), # 3090
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 16,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
), # 3090
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
), # 3090
],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def quant_fused_matmul_248_kernel(
a_ptr, c_ptr, b1_ptr,
scales1_ptr, zeros1_ptr,
g1_ptr, b2_ptr,
scales2_ptr, zeros2_ptr,
g2_ptr,
M, N, K,
bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""
Computes: C = silu(A * B1) * (A * B2)
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (1, N) float16
zeros is of shape (1, N//8) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
g1_ptrs = g1_ptr + offs_k
g2_ptrs = g2_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales1_ptrs = scales1_ptr + offs_bn[None, :]
scales2_ptrs = scales2_ptr + offs_bn[None, :]
zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g1_idx = tl.load(g1_ptrs)
g2_idx = tl.load(g2_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
zeros1 = (zeros1 + 1)
zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
zeros2 = (zeros2 + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
b2 = tl.load(b2_ptrs)
# Now we need to unpack b (which is N-bit values) into 32-bit values
b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values
b1 = (b1 - zeros1) * scales1 # Scale and shift
accumulator1 += tl.dot(a, b1)
b2 = (b2 >> shifter[:, None]) & maxq
b2 = (b2 - zeros2) * scales2
accumulator2 += tl.dot(a, b2)
a_ptrs += BLOCK_SIZE_K
b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g1_ptrs += BLOCK_SIZE_K
g2_ptrs += BLOCK_SIZE_K
accumulator1 = silu(accumulator1)
c = accumulator1 * accumulator2
c = c.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
else:
quant_fused_matmul_248_kernel = None
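An unfused reference of what the Triton kernel above computes, included for clarity (a sketch that operates on already-dequantized fp16 weights rather than the packed int32 layout):

import torch
import torch.nn.functional as F

def fused_mlp_reference(a, w_gate, w_up):
    # C = silu(A @ B1) * (A @ B2), matching the kernel docstring
    return F.silu(a @ w_gate) * (a @ w_up)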
class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
def __init__(
self,
gate_proj,
down_proj,
up_proj,
):
super().__init__()
self.infeatures = gate_proj.infeatures
self.intermediate_size = gate_proj.outfeatures
self.outfeatures = down_proj.outfeatures
self.bits = gate_proj.bits
self.maxq = gate_proj.maxq
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj
def forward(self, x):
return self.down_proj(self.triton_llama_mlp(x))
def triton_llama_mlp(self, x):
with torch.cuda.device(x.device):
out_shape = x.shape[:-1] + (self.intermediate_size, )
x = x.reshape(-1, x.shape[-1])
M, K = x.shape
N = self.intermediate_size
c = torch.empty((M, N), device=x.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
quant_fused_matmul_248_kernel[grid](
x, c, self.gate_proj.qweight,
self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,
self.up_proj.qweight,
self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,
M, N, K,
self.bits, self.maxq,
x.stride(0), x.stride(1),
self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),
c.stride(0), c.stride(1),
self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)
)
c = c.reshape(out_shape)
return c
@classmethod
def inject_to_model(cls, model, use_triton=False, **kwargs):
if not use_triton:
logger.warning(f"skip module injection for {cls.__name__} not support integrate without triton yet.")
return
elif not TRITON_AVAILABLE:
logger.warning(f"skip module injection for triton is not installed.")
return
for name, m in model.named_modules():
if not isinstance(m, LlamaMLP):
continue
mlp = cls(m.gate_proj, m.down_proj, m.up_proj)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
setattr(parent, child_name, mlp)
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, cls):
continue
k = m.infeatures
n = m.intermediate_size
if (k, n) not in kn_values:
kn_values[(k, n)] = m
logger.info(f'Found {len(kn_values)} unique fused mlp KN values.')
logger.info('Warming up autotune cache ...')
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2 ** m
for (k, n), (modules) in kn_values.items():
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
modules.triton_llama_mlp(a)
del kn_values
__all__ = ["FusedLlamaMLPForQuantizedModel"]

View file

@ -0,0 +1,56 @@
import torch.nn as nn
class GeneralQuantLinear(nn.Linear):
def __init__(self, quant_linear_module):
super().__init__(
in_features=quant_linear_module.infeatures,
out_features=quant_linear_module.outfeatures,
bias=True
)
self.infeatures = quant_linear_module.infeatures
self.outfeatures = quant_linear_module.outfeatures
self.bits = quant_linear_module.bits
self.group_size = quant_linear_module.group_size
self.maxq = quant_linear_module.maxq
self.weight.requires_grad = False
self.weight.data = quant_linear_module.qweight
self.register_buffer('qweight', quant_linear_module.qweight)
self.bias.data = quant_linear_module.bias
self.qweight.requires_grad = False
self.bias.requires_grad = False
self.register_buffer('qzeros', quant_linear_module.qzeros)
self.register_buffer('scales', quant_linear_module.scales)
self.register_buffer('g_idx', quant_linear_module.g_idx)
if hasattr(quant_linear_module, "wf"):
self.wf = quant_linear_module.wf
if hasattr(quant_linear_module, "kernel_switch_threshold"):
self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
if hasattr(quant_linear_module, "autogptq_cuda_available"):
self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
self.trainable = quant_linear_module.trainable
self.forward = quant_linear_module.forward
@classmethod
def inject_to_model(cls, model, target_module_type):
for name, m in model.named_modules():
if not isinstance(m, target_module_type):
continue
new_m = cls(m)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
setattr(parent, child_name, new_m)
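A minimal usage sketch for the wrapper above (the import path is taken from this repository's layout and should be treated as an assumption):

from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear

# `model` is assumed to be an already-quantized model whose linear layers are exllama QuantLinear.
# After injection each of them exposes nn.Linear-style attributes (in_features, out_features,
# weight, bias), which downstream tooling such as peft expects.
GeneralQuantLinear.inject_to_model(model, target_module_type=ExllamaQuantLinear)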

View file

@ -4,36 +4,45 @@ from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
logger = getLogger(__name__)
try:
import quant_cuda
_quant_cuda_available = True
import autogptq_cuda_256
import autogptq_cuda_64
_autogptq_cuda_available = True
except ImportError:
logger.warning('CUDA extension not installed.')
_quant_cuda_available = False
autogptq_cuda_256 = None
autogptq_cuda_64 = None
_autogptq_cuda_available = False
class QuantLinear(nn.Module):
QUANT_TYPE = "cuda"
def __init__(
self,
bits,
groupsize,
group_size,
infeatures,
outfeatures,
bias,
kernel_switch_threshold=128,
trainable=False
):
super().__init__()
global _autogptq_cuda_available
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
_autogptq_cuda_available = False
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.groupsize = groupsize if groupsize != -1 else infeatures
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2 ** self.bits - 1
self.register_buffer(
@ -42,15 +51,15 @@ class QuantLinear(nn.Module):
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
@ -59,28 +68,38 @@ class QuantLinear(nn.Module):
# The matmul fallback (when the CUDA kernel cannot be used) is performed by unpacking the weights and using torch.matmul
if self.bits in [2, 4, 8]:
self.register_buffer(
'wf',
torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0),
persistent=False
)
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
elif self.bits == 3:
self.register_buffer(
'wf',
torch.tensor(
[
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
[0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
[0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],
],
dtype=torch.int32).reshape(1, 3, 12),
persistent=False
)
self.wf = torch.tensor(
[
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
[0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
[0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],
],
dtype=torch.int32
).reshape(1, 3, 12)
self.kernel_switch_threshold = kernel_switch_threshold
self.quant_cuda_available = _quant_cuda_available
self.autogptq_cuda_available = _autogptq_cuda_available
self.autogptq_cuda = autogptq_cuda_256
if infeatures % 256 != 0 or outfeatures % 256 != 0:
self.autogptq_cuda = autogptq_cuda_64
if infeatures % 64 != 0 or outfeatures % 64 != 0:
self.autogptq_cuda_available = False
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
@ -95,7 +114,7 @@ class QuantLinear(nn.Module):
intweight.append(
torch.round(
(
linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
@ -177,28 +196,30 @@ class QuantLinear(nn.Module):
def forward(self, x: torch.Tensor):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
if self.quant_cuda_available and (
if self.autogptq_cuda_available and (
self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
):
out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32)
out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
if self.bits == 2:
quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
self.autogptq_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 3:
quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
self.autogptq_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 4:
quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
self.autogptq_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
elif self.bits == 8:
quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
self.autogptq_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
out = out.half()
else:
if self.wf.device != self.qzeros.device:
self.wf = self.wf.to(self.qzeros.device)
if self.bits in [2, 4, 8]:
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0)
).to(torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)
@ -207,7 +228,7 @@ class QuantLinear(nn.Module):
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1)
).to(torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
elif self.bits == 3:
zeros = self.qzeros.reshape(
self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
@ -233,12 +254,23 @@ class QuantLinear(nn.Module):
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]))
out = torch.matmul(x.half(), weights)
out = out.reshape(out_shape)
num_itr = self.g_idx.shape[0]//x.shape[-1]
if num_itr == 1:
weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]))
else:
num_dim = self.g_idx.shape[0]//num_itr
weights = []
for i in range(num_itr):
scale_i = self.scales[:,i*num_dim:(i+1)*num_dim]
weight_i = weight[:,i*num_dim:(i+1)*num_dim]
zeros_i = zeros[:,i*num_dim:(i+1)*num_dim]
g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
weights = torch.cat(weights,dim=1)
out = torch.matmul(x.to(weights.dtype), weights)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
return out.to(x.dtype)
__all__ = ["QuantLinear"]

View file

@ -0,0 +1,275 @@
import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
logger = getLogger(__name__)
try:
import autogptq_cuda_256
import autogptq_cuda_64
_autogptq_cuda_available = True
except ImportError:
logger.warning('CUDA extension not installed.')
autogptq_cuda_256 = None
autogptq_cuda_64 = None
_autogptq_cuda_available = False
class QuantLinear(nn.Module):
QUANT_TYPE = "cuda-old"
def __init__(
self,
bits,
group_size,
infeatures,
outfeatures,
bias,
use_cuda_fp16=True,
kernel_switch_threshold=128,
trainable=False
):
super().__init__()
global _autogptq_cuda_available
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
if trainable:
_autogptq_cuda_available = False
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2 ** self.bits - 1
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
self.half_indim = self.infeatures // 2
self.use_cuda_fp16 = use_cuda_fp16 if bits != 8 else False
# The matmul fallback (when the CUDA kernel cannot be used) is performed by unpacking the weights and using torch.matmul
if self.bits in [2, 4, 8]:
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
elif self.bits == 3:
self.wf = torch.tensor(
[
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
[0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
[0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],
],
dtype=torch.int32
).reshape(1, 3, 12)
self.kernel_switch_threshold = kernel_switch_threshold
self.autogptq_cuda_available = _autogptq_cuda_available
self.autogptq_cuda = autogptq_cuda_256
if infeatures % 256 != 0 or outfeatures % 256 != 0:
self.autogptq_cuda = autogptq_cuda_64
if infeatures % 64 != 0 or outfeatures % 64 != 0:
self.autogptq_cuda_available = False
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
g_idx = idx // self.group_size
intweight.append(
torch.round(
(W[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
elif self.bits == 3:
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i))
i += 10
qweight[row] |= intweight[i] << 30
row += 1
qweight[row] |= (intweight[i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
i += 10
qweight[row] |= intweight[i] << 31
row += 1
qweight[row] |= (intweight[i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
i += 10
row += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
elif self.bits == 3:
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
i += 10
qzeros[:, col] |= zeros[:, i] << 30
col += 1
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
i += 10
qzeros[:, col] |= zeros[:, i] << 31
col += 1
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
i += 10
col += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
if self.autogptq_cuda_available is True and (
self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold
):
out = torch.zeros(x.shape[0], out_shape[-1], dtype=torch.float, device=x.device)
if self.use_cuda_fp16:
x = x.half()
if self.bits == 2:
self.autogptq_cuda.vecquant2matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
elif self.bits == 3:
self.autogptq_cuda.vecquant3matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
elif self.bits == 4:
self.autogptq_cuda.vecquant4matmul_faster_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size, self.half_indim)
else:
raise NotImplementedError("Only 2,3,4 bits are supported.")
else:
x = x.float()
if self.bits == 2:
self.autogptq_cuda.vecquant2matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
elif self.bits == 3:
self.autogptq_cuda.vecquant3matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
elif self.bits == 4:
self.autogptq_cuda.vecquant4matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
elif self.bits == 8:
self.autogptq_cuda.vecquant8matmul_old(x, self.qweight, out, self.scales.float(), self.qzeros, self.group_size)
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
else:
if self.wf.device != self.qzeros.device:
self.wf = self.wf.to(self.qzeros.device)
if self.bits in [2,4,8]:
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
scales = self.scales
scales = scales.reshape(-1, 1, scales.shape[-1])
weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
weight = weight.reshape(-1, self.group_size, weight.shape[2])
elif self.bits == 3:
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
zeros = (zeros >> self.wf.unsqueeze(0))
zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4)
zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6)
zeros = zeros & 0x7
zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)
zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
scales = self.scales
scales = scales.reshape(-1, 1, scales.shape[-1])
weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
weight = (weight >> self.wf.unsqueeze(-1))&0x7
weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4)
weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6)
weight = weight & 0x7
weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1)
weight = weight.reshape(-1, self.group_size, weight.shape[2])
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
weight = (scales * (weight - zeros))
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
out = torch.matmul(x.to(weight.dtype), weight)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out.to(x.dtype)
__all__ = ["QuantLinear"]

View file

@ -0,0 +1,171 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllama
from logging import getLogger
import torch
import torch.nn as nn
import math
import numpy as np
import transformers
logger = getLogger(__name__)
try:
from exllama_kernels import make_q4, q4_matmul
except ImportError:
logger.error('exllama_kernels not installed.')
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllama"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllama kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
if trainable:
raise NotImplementedError("Exllama kernel does not support training.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.trainable = trainable
self.maxq = 2 ** self.bits - 1
assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.width = self.qweight.shape[1]
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx.to("cpu") if self._use_act_order else None,
self.qweight.device.index
)
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out = ext_q4_matmul(x.half(), self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View file

@ -0,0 +1,188 @@
# Adapted from turboderp exllamav2: https://github.com/turboderp/exllamav2
from logging import getLogger
import torch
import torch.nn as nn
import math
logger = getLogger(__name__)
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
logger.error('exllamav2_kernels not installed.')
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def _torch_device(idx):
if idx == -1: return "cpu"
return f"cuda:{idx}"
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
"""
Create Q matrix
"""
# EXL2
# won't work at the moment because the tensors are not the same.
if "q_weight" in w:
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
return make_q_matrix(w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
temp_dq)
# GPTQ
elif "qweight" in w:
if w["scales"].dtype == torch.float:
w["scales"] = w["scales"].half()
# GPTQ with g_idx (act_order)
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
temp_dq)
# GPTQ without g_idx
else:
return make_q_matrix(w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,
temp_dq)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
if trainable:
raise NotImplementedError("Exllamav2 kernel does not support training.")
self.q_handle = None
self.q_tensors = None
self.padding = - outfeatures % 32
self.infeatures = infeatures
self.outfeatures = outfeatures + self.padding
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.trainable = trainable
self.maxq = 2 ** self.bits - 1
assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
"qweight":self.qweight,
"qzeros":self.qzeros,
"scales":self.scales,
"g_idx":self.g_idx
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq
)
def forward(self, x, force_cuda = False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
scratch_bytes: int
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device_idx, scratch_bytes):
self.device_idx = device_idx
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
def get_scratch_slice(self, size_bytes):
if self.scratch is None: self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
scratch_slice = self.scratch.narrow(0, 0, size_half)
return scratch_slice
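
To make the scratch-space accounting concrete, a small sketch (pure arithmetic, illustrative sizes) of the byte budget these helpers would reserve for a single 4096x4096 layer:

# Illustrative only: mirrors temp_dq_size / temp_fwd_size / scratch_space_fixed for one layer.
infeatures, outfeatures = 4096, 4096
max_input_len, max_batch_size = 2048, 8

temp_dq_bytes = infeatures * outfeatures * 2 + 128                       # fp16 copy of the dequantized weight
temp_fwd_bytes = outfeatures * max_input_len * max_batch_size * 4 + 128
scratch_bytes = temp_dq_bytes + temp_fwd_bytes                           # what scratch_space_fixed() returns

# get_scratch_slice() rounds the request up to a multiple of 128 bytes and
# hands back a slice of the fp16 scratch tensor (2 bytes per element).
request = temp_dq_bytes
rounded = ((request + 127) // 128) * 128
print(scratch_bytes / 2**20, "MiB total;", rounded // 2, "fp16 elements for temp_dq")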


@@ -0,0 +1,262 @@
from copy import deepcopy
import torch
from torch import nn
from tqdm import tqdm
import gc
import math
import numpy as np
from gekko import GEKKO
from logging import getLogger
logger = getLogger(__name__)
try:
import cQIGen as qinfer
except ImportError:
logger.error('cQIGen not installed.')
raise
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
m = GEKKO() # create GEKKO model
#cinfergen if bits==3:
# tu = tu*3
B = m.Const(value=bits)
TP = m.Const(value=T//p)
k = m.Var(1,integer=True,lb=1)
z = m.Var(1,integer=True,lb=1)
w = m.Var(1,integer=True,lb=1)
y = m.Var(1,integer=True,lb=1)
v = m.Var(1,integer=True,lb=1)
mb = m.Var(mu,integer=True,lb=1)
if gs != -1:
gg = m.Var(1,integer=True,lb=1)
tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
L = m.Var(integer=True,lb=0,ub=l1)
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
m.Equation(mb * k == M)
if gs != -1:
m.Equation(gs * gg == mb)
# m.Equation(tb * z == T)
m.Equation(tb * z == TP)
m.Equation(mu * w == mb)
m.Equation(tu * y == tb)
# m.Equation(tb * v == tt)
m.Maximize(L)
m.options.SOLVER = 1
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 2', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
try:
m.solve(disp=False)
except:
try:
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 1', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
m.solve(disp=False)
except:
# mytb = T//p
mytb = tu
if gs != -1:
mymb = gs
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
mymb += gs
while M % mymb != 0:
mymb -= gs
return (int(mymb), int(mytb))
else:
mymb = mu
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
mymb += mu
while M % mymb != 0:
mymb -= mu
return (int(mymb), int(mytb))
return (int(mb.value[0]), int(tb.value[0]))
params = {}
def compute_reductions(x, gs=-1, cpp=True):
if cpp:
if len(x.shape) != 1:
rows, cols = x.shape
else:
rows = 1
cols = x.shape[0]
if gs == -1:
out = torch.zeros(rows).float().contiguous()
mygs = cols
else:
out = torch.zeros(rows, cols // gs).float().contiguous()
mygs = gs
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
return out
if gs == -1:
if len(x.shape) != 1:
return torch.sum(x,1)
else:
return torch.sum(x)
else:
if len(x.shape) != 1:
rows, cols = x.shape
out = torch.zeros(rows, cols // gs).float().contiguous()
for i in range(cols // gs):
out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
return out
else:
cols = x.shape[0]
out = torch.zeros(cols // gs).float().contiguous()
for i in range(cols // gs):
out[i] = torch.sum(x[i*gs:(i+1)*gs])
return out
def process_zeros_scales(zeros, scales, bits, M):
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != M:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != M:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
return new_zeros, new_scales
class QuantLinear(nn.Module):
QUANT_TYPE = "qigen"
def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
super().__init__()
if bits not in [2, 4]:
raise NotImplementedError("Only 2,4 bits are supported.")
if trainable:
raise NotImplementedError("Qigen kernel does not support training.")
self.bits = bits
pack = 32 // bits
self.infeatures = infeatures
self.outfeatures = outfeatures
n = hint
m = self.infeatures
t = self.outfeatures
#registers for now are fixed
if bits == 3:
packed = 32
unroll = 3
nu = 1 #args.n
mu = 32
tu = 32
else:
packed = 32 // bits
unroll = 2
nu = 1 #args.n
mu = 16
tu = 32
nb = n # it's always small for transformers
global params
if (m,t) in params:
mb = params[(m,t)][0]
tb = params[(m,t)][1]
else:
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
params[(m,t)] = (mb,tb)
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert(np.sum(split) == t)
split = split.astype(int)
self.tt = int(split[0])
if split[0] == split[-1]:
self.cutoff = int(p+1)
else:
self.cutoff = int(idx + 1)
self.mb = mb #// packed
self.tb = tb
self.group_size = group_size
self.register_buffer('bias', torch.zeros(self.outfeatures))
self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
if bits == 4:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
elif bits == 3:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
elif bits == 2:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
B = x.shape[0]
new_x = x.T.contiguous()
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
if self.group_size == -1:
if self.bits == 4:
qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 2:
qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 3:
qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
else:
if self.bits == 4:
qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 2:
qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 3:
qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
return out.reshape(out_shape)
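
The cpp=False branch of compute_reductions above has a direct PyTorch equivalent. A small standalone sketch (illustrative shapes, no cQIGen extension needed):

import torch

# Per-row sums over each group of gs input features, as compute_reductions(x, gs, cpp=False) computes.
x = torch.arange(2 * 8, dtype=torch.float32).reshape(2, 8)   # (rows=2, cols=8)
gs = 4                                                        # group size

rows, cols = x.shape
out = torch.zeros(rows, cols // gs)
for i in range(cols // gs):
    out[:, i] = torch.sum(x[:, i * gs:(i + 1) * gs], 1)

print(out)   # tensor([[ 6., 22.], [38., 54.]])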


@@ -0,0 +1,186 @@
import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
from ..triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
try:
from ..triton_utils.kernels import (
quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248,
QuantLinearFunction, QuantLinearInferenceOnlyFunction
)
except ImportError:
logger.error('triton not installed.')
raise
class QuantLinear(nn.Module, TritonModuleMixin):
QUANT_TYPE = "triton"
def __init__(
self,
bits,
group_size,
infeatures,
outfeatures,
bias,
trainable=False
):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
if infeatures % 32 != 0 or outfeatures % 32 != 0:
raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2 ** self.bits - 1
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(
W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
out = quant_linear_fn.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq
)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
"""
Pre-tunes the quantized kernel
"""
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, cls):
continue
k = m.infeatures
n = m.outfeatures
if (k, n) not in kn_values:
kn_values[(k, n)] = (m.qweight, m.scales, m.qzeros, m.g_idx, m.bits, m.maxq)
logger.info(f'Found {len(kn_values)} unique KN Linear values.')
logger.info('Warming up autotune cache ...')
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2 ** m
for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
if transpose:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
a = torch.randn(m, n, dtype=torch.float16, device=model.device)
transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
else:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values
__all__ = ["QuantLinear"]


@@ -1,485 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
from logging import getLogger
logger = getLogger(__name__)
try:
import triton
import triton.language as tl
from .triton_utils import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=2,
num_warps=8
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
num_stages=3,
num_warps=8
),
triton.Config(
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
num_stages=2,
num_warps=4
),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K,
bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(configs=[
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
num_stages=4,
num_warps=4
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
num_stages=2,
num_warps=8
),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
num_stages=3,
num_warps=8
),
triton.Config(
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=2,
num_warps=4
),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True
)
@triton.jit
def transpose_matmul_248_kernel(
a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K,
bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except ImportError:
logger.warning('triton not installed.')
raise
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
)
matmul_248_kernel[grid](
input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1],
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
transpose_matmul_248_kernel[grid](
input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], output_dim,
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
class QuantLinear(nn.Module):
def __init__(
self,
bits,
groupsize,
infeatures,
outfeatures,
bias
):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.groupsize = groupsize if groupsize != -1 else infeatures
self.maxq = 2 ** self.bits - 1
self.register_buffer(
'qweight',
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(
linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq
)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
def autotune_warmup_linear(model, transpose=False, seqlen=2048):
"""
Pre-tunes the quantized kernel
"""
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, QuantLinear):
continue
k = m.infeatures
n = m.outfeatures
if (k, n) not in kn_values:
kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
logger.info(f'Found {len(kn_values)} unique KN Linear values.')
logger.info('Warming up autotune cache ...')
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2 ** m
for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
a = torch.randn(m, k, dtype=torch.float16, device='cuda')
matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
if transpose:
a = torch.randn(m, n, dtype=torch.float16, device='cuda')
transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values
__all__ = [
"QuantLinear",
"autotune_warmup_linear"
]


@@ -0,0 +1,402 @@
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from logging import getLogger
import triton
import triton.language as tl
from . import custom_autotune
logger = getLogger(__name__)
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8
)
],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def quant_matmul_248_kernel(
a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K,
bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 256,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8
)
],
key=['M', 'N', 'K'],
nearest_power_of_two=True
)
@triton.jit
def transpose_quant_matmul_248_kernel(
a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K,
bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def silu(x):
return x * tl.sigmoid(x)
def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
)
quant_matmul_248_kernel[grid](
input, qweight, output,
scales.to(input.dtype), qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1],
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
def transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
transpose_quant_matmul_248_kernel[grid](
input, qweight, output,
scales.to(input.dtype), qzeros, g_idx,
input.shape[0], qweight.shape[1], output_dim,
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
grid = lambda META: (
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),
)
quant_matmul_248_kernel[grid](
input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1],
bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)
)
return output
class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
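
For sanity-checking shapes and semantics off-GPU, here is a rough pure-PyTorch sketch of the dequantize-then-matmul contract that quant_matmul_248 implements for bits=4 (the real kernel runs in fp16 on CUDA; the tensors below are random placeholders):

import torch

def dequant_ref(qweight, qzeros, scales, g_idx, bits=4):
    # Mirrors the kernel's unpacking: shift/mask nibbles, add 1 to zeros, scale per group.
    maxq = 2 ** bits - 1
    pack = 32 // bits
    shifts = torch.arange(pack, dtype=torch.int32) * bits
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & maxq   # (K//8, 8, N) -> (K, N)
    w = w.reshape(-1, qweight.shape[1])
    z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, -1)) & maxq    # (G, N//8, 8) -> (G, N)
    z = z.reshape(qzeros.shape[0], -1) + 1                       # kernel does zeros = zeros + 1
    g = g_idx.long()
    return (w - z[g]) * scales[g]                                # dequantized B, shape (K, N)

# Dummy 4-bit layer: K=64 in-features, N=32 out-features, group_size=16.
K, N, gs = 64, 32, 16
qweight = torch.randint(0, 2**31 - 1, (K // 8, N), dtype=torch.int32)
qzeros = torch.randint(0, 2**31 - 1, (K // gs, N // 8), dtype=torch.int32)
scales = torch.rand(K // gs, N)
g_idx = torch.arange(K, dtype=torch.int32) // gs

x = torch.randn(4, K)
out = x @ dequant_ref(qweight, qzeros, scales, g_idx)   # (4, N): C = A x B as in the docstring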


@@ -0,0 +1,4 @@
class TritonModuleMixin:
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
pass


@@ -23,7 +23,7 @@ class GPTQ:
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
@@ -60,7 +60,7 @@ class GPTQ:
self.H += inp.matmul(inp.t())
def fasterquant(
self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
@@ -80,10 +80,26 @@ class GPTQ:
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i:(i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
@@ -96,11 +112,6 @@ class GPTQ:
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
@@ -115,15 +126,21 @@ class GPTQ:
w = W1[:, i]
d = Hinv1[i, i]
if groupsize != -1:
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2
@@ -147,17 +164,19 @@ class GPTQ:
logger.info(f'duration: {(time.time() - tick)}')
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
groupsize = groupsize if groupsize != -1 else self.columns
g_idx = [i // groupsize for i in range(self.columns)]
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
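
A tiny standalone illustration of the static_groups + act-order bookkeeping introduced above: each column keeps the quantizer of its original group, so g_idx is first built in processing order from perm and then mapped back to original column order with invperm.

import torch

columns, group_size = 8, 4
H_diag = torch.tensor([0.1, 5.0, 0.3, 4.0, 2.0, 0.2, 3.0, 1.0])   # illustrative Hessian diagonal

perm = torch.argsort(H_diag, descending=True)    # act-order: quantize "important" columns first
invperm = torch.argsort(perm)

g_idx = torch.tensor([int(perm[i]) // group_size for i in range(columns)])
print(g_idx.tolist())            # group id of each column in processing order
print(g_idx[invperm].tolist())   # back in original column order: [0, 0, 0, 0, 1, 1, 1, 1]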


@@ -0,0 +1 @@
from .perplexity_utils import Perplexity


@@ -0,0 +1,48 @@
import gc
import torch
def exllama_set_max_input_length(model, max_input_length: int):
"""
This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the
default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
"""
# The import is set here to avoid a global import. Arguably this is quite ugly, it would be better to have lazy loading.
from exllama_kernels import prepare_buffers, cleanup_buffers_cuda
if not model.quantize_config.desc_act:
raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")
device_to_buffers_size = {}
for device, buffers in model.device_to_buffers.items():
device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}
# For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
for key in list(model.device_to_buffers.keys()):
del model.device_to_buffers[key]
model.device_to_buffers = None
del model.device_to_buffers
gc.collect()
torch.cuda.empty_cache()
cleanup_buffers_cuda()
device_to_buffers = {}
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])
# Buffers need to be persistent to avoid any bug.
model.device_to_buffers = device_to_buffers
return model
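
A usage sketch for the helper above; the checkpoint name and loading call are illustrative and assume an act-order (desc_act=True) model served by the exllama backend:

# Illustrative only: requires the exllama kernels and a desc_act=True GPTQ checkpoint.
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

model = AutoGPTQForCausalLM.from_quantized("TheBloke/Llama-2-7B-GPTQ", device="cuda:0")
# Grow the act-order buffers to accept prompts up to 4096 tokens without reloading the weights.
model = exllama_set_max_input_length(model, max_input_length=4096)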


@@ -0,0 +1,86 @@
from packaging.version import parse as parse_version
from logging import getLogger
import torch
try:
import triton
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
try:
import autogptq_cuda_256
import autogptq_cuda_64
AUTOGPTQ_CUDA_AVAILABLE = True
except:
AUTOGPTQ_CUDA_AVAILABLE = False
try:
import exllama_kernels
EXLLAMA_KERNELS_AVAILABLE = True
except:
EXLLAMA_KERNELS_AVAILABLE = False
try:
import exllamav2_kernels
EXLLAMAV2_KERNELS_AVAILABLE = True
except:
EXLLAMAV2_KERNELS_AVAILABLE = False
try:
import cQIGen as qinfer
QIGEN_AVAILABLE = True
except:
QIGEN_AVAILABLE = False
logger = getLogger(__name__)
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2:bool = False, use_qigen: bool = False):
if use_qigen:
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
else:
if use_triton:
if torch.version.hip:
logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
else:
if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
elif not desc_act or group_size == -1:
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
else:
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
return QuantLinear
def compare_transformers_version(
version: str = "v4.28.0",
op: str = "eq"
):
assert op in ["eq", "lt", "le", "gt", "ge"]
from transformers import __version__
return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
def compare_pytorch_version(
version: str = "v2.0.0",
op: str = "eq"
):
assert op in ["eq", "lt", "le", "gt", "ge"]
from torch import __version__
return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
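
A short usage sketch of these helpers (the import path and the printed QUANT_TYPE value are assumptions based on the selection logic above):

# Illustrative: pick the kernel the selector would choose for a 4-bit, group_size=128,
# non-act-order layer with both exllama backends disabled.
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear, compare_pytorch_version

QuantLinear = dynamically_import_QuantLinear(
    use_triton=False, desc_act=False, group_size=128, bits=4,
    disable_exllama=True, disable_exllamav2=True,
)
print(QuantLinear.QUANT_TYPE)                      # e.g. "cuda-old" when only the CUDA kernels are built
print(compare_pytorch_version("v2.0.0", op="ge"))  # True on PyTorch >= 2.0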


@@ -0,0 +1,423 @@
import warnings
import re
from contextlib import contextmanager
from dataclasses import asdict
from enum import Enum
from typing import List, Optional
import torch
from peft import get_peft_model, PeftConfig, PeftModel, PeftType
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel, Embedding
from peft.tuners.adalora import AdaLoraConfig, AdaLoraLayer, AdaLoraModel
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.utils.other import _get_submodules
from ..modeling._base import BaseGPTQForCausalLM
class GPTQLoraConfig(LoraConfig):
injected_fused_attention: bool = False
injected_fused_mlp: bool = False
class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
def __init__(
self,
adapter_name: str,
linear_module: torch.nn.Linear,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
LoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
self.linear_module = linear_module
self.weight.requires_grad = False
self.weight = self.linear_module.weight
self.bias = self.linear_module.bias
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
def merge(self):
raise NotImplementedError("gptq model not support merge lora adapter")
def unmerge(self):
raise NotImplementedError("gptq model not support unmerge lora adapter")
def forward(self, x: torch.Tensor):
previous_dtype = x.dtype
if self.active_adapter not in self.lora_A.keys():
return self.linear_module(x)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = self.linear_module(x)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = self.linear_module(x)
lora_B = self.lora_B[self.active_adapter]
lora_A = self.lora_A[self.active_adapter]
lora_dropout = self.lora_dropout[self.active_adapter]
scale = self.scaling[self.active_adapter]
x = x.type_as(lora_A.weight.data)
adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
result += adapter_result
else:
result = self.linear_module(x)
result = result.to(previous_dtype)
return result
class GPTQLoraModel(LoraModel):
def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = False
if hasattr(target, "bias"):
bias = target.bias is not None
if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
if isinstance(target, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
else:
if isinstance(target, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and its subclasses are supported."
)
new_module = GPTQLoraLinear(adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
if not isinstance(new_module, GPTQLoraLinear):
new_module.weight = old_module.weight
if hasattr(old_module, "bias"):
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.weight.device)
def merge_adapter(self):
raise NotImplementedError("gptq model not support merge ada lora adapter")
def unmerge_adapter(self):
raise NotImplementedError("gptq model not support unmerge ada lora adapter")
def merge_and_unload(self):
raise NotImplementedError("gptq model not support merge and unload")
class GPTQAdaLoraConfig(AdaLoraConfig):
injected_fused_attention: bool = False
injected_fused_mlp: bool = False
class GPTQSVDLinear(torch.nn.Linear, AdaLoraLayer):
def __init__(
self,
adapter_name: str,
linear_module: torch.nn.Linear,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
AdaLoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
self.linear_module = linear_module
self.weight.requires_grad = False
self.weight = self.linear_module.weight
self.bias = self.linear_module.bias
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def merge(self):
raise NotImplementedError("gptq model not support merge lora adapter")
def unmerge(self):
raise NotImplementedError("gptq model not support unmerge lora adapter")
def forward(self, x: torch.Tensor):
if self.active_adapter not in self.lora_A.keys():
return self.linear_module(x)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = self.linear_module(x)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = self.linear_module(x)
result += (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
result = self.linear_module(x)
return result
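The forward above adds the AdaLoRA (SVD-parameterized) update on top of the quantized linear's output. Below is a minimal standalone sketch of just that delta term; the shapes and values are illustrative assumptions, not taken from the diff, and only the tensor algebra mirrors the forward pass.

# Illustrative only: names, shapes and values below are assumptions.
import torch

r, in_features, out_features = 4, 16, 8
x = torch.randn(2, in_features)          # a batch of activations
lora_A = torch.randn(r, in_features)     # right factor
lora_E = torch.randn(r, 1)               # learned singular values, one per rank
lora_B = torch.randn(out_features, r)    # left factor
scaling, ranknum = 2.0, float(r)

delta = (x @ (lora_A * lora_E).T @ lora_B.T) * scaling / (ranknum + 1e-5)
print(delta.shape)  # torch.Size([2, 8]); this is what gets added to self.linear_module(x)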
class GPTQAdaLoraModel(AdaLoraModel):
def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.init_r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.init_r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and its subclasses are supported."
)
new_module = GPTQSVDLinear(adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.weight.device)
    def merge_adapter(self):
        raise NotImplementedError("gptq model does not support merging the AdaLoRA adapter")
    def unmerge_adapter(self):
        raise NotImplementedError("gptq model does not support unmerging the AdaLoRA adapter")
    def merge_and_unload(self):
        raise NotImplementedError("gptq model does not support merge_and_unload")
def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True):
if not ignore:
ignore = []
lm_head_name = model.lm_head_name
if ignore_lm_head and lm_head_name not in ignore:
ignore.append(lm_head_name)
results = set()
for n, m in model.named_modules():
if isinstance(m, torch.nn.Linear):
res = n.split('.')[-1]
if res not in ignore:
results.add(res)
return list(results)
@contextmanager
def hijack_peft_mappings():
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
try:
yield
except:
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
raise
finally:
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
def get_gptq_peft_model(
model: BaseGPTQForCausalLM,
peft_config: PeftConfig = None,
model_id: str = None,
adapter_name: str = "default",
auto_find_all_linears: bool = True,
train_mode: bool = False
):
if train_mode and not model.trainable:
model.enable_trainable_mode()
if train_mode and not peft_config:
raise ValueError("peft_config not specified when in train mode.")
if not train_mode and not model_id:
raise ValueError("model_id(where to load adapters) not specified when in inference mode.")
if model.fused_attn_module_type is not None and not model.injected_fused_attention:
peft_types = [PeftType.LORA.value, PeftType.ADALORA.value]
        warnings.warn(
            f"You can ignore this warning if the peft type you are using is not in {peft_types}.\n"
            f"{model.__class__.__name__} supports fused attention injection, but it is not enabled this time. "
            "If you are training adapters, you must also disable fused attention injection when loading the quantized "
            "base model at inference time, otherwise the adapters may not be added to the base model properly. "
            "If you are loading adapters for inference, you can check the adapter's config file to see "
            "whether the adapters were trained with a base model that had fused attention injection disabled."
        )
if model.injected_fused_mlp:
raise NotImplementedError("GPTQ model that enables fused mlp injection is not supported to integrate with peft.")
if train_mode:
peft_type = peft_config.peft_type
if not isinstance(peft_type, str):
peft_type = peft_type.value
if peft_type in [PeftType.LORA.value, PeftType.ADALORA.value]:
if auto_find_all_linears:
peft_config.target_modules = find_all_linear_names(model, ignore_lm_head=True)
if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
peft_config = GPTQLoraConfig(**peft_config.to_dict())
if peft_type == PeftType.ADALORA.value and not isinstance(peft_config, GPTQAdaLoraConfig):
peft_config = GPTQAdaLoraConfig(**peft_config.to_dict())
peft_config.injected_fused_attention = model.injected_fused_attention
peft_config.injected_fused_mlp = model.injected_fused_mlp
if peft_type == PeftType.ADAPTION_PROMPT.value:
if peft_config.adapter_layers > model.config.num_hidden_layers:
warnings.warn(
f"model has only {model.config.num_hidden_layers} layers "
f"but adapter_layers is set to {peft_config.adapter_layers}, "
f"will reset value to {model.config.num_hidden_layers}."
)
peft_config.adapter_layers = model.config.num_hidden_layers
if model.injected_fused_attention:
raise NotImplementedError(
"model with fused attention injected isn't supported to use ADAPTION_PROMPT peft type yet."
)
with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
except:
            raise NotImplementedError(
                f"{model.__class__.__name__} does not support the {peft_config.peft_type.value} peft type yet."
            )
return peft_model
__all__ = [
"GPTQLoraConfig",
"GPTQLoraModel",
"GPTQAdaLoraConfig",
"GPTQAdaLoraModel",
"find_all_linear_names",
"get_gptq_peft_model"
]
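A minimal training-mode usage sketch for the API exported above. The model id, LoRA hyperparameters, import paths and the from_quantized keyword arguments are assumptions about AutoGPTQForCausalLM's loader; with auto_find_all_linears=True the target modules are filled in through find_all_linear_names.

# Sketch only; model id, loader kwargs and import paths are illustrative assumptions.
from peft import TaskType
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.peft_utils import GPTQLoraConfig, get_gptq_peft_model

model = AutoGPTQForCausalLM.from_quantized(
    "some/quantized-model",         # placeholder model id
    device="cuda:0",
    inject_fused_attention=False,   # keep fused attention off when training adapters (see warning above)
    trainable=True,
)
peft_config = GPTQLoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)
peft_model = get_gptq_peft_model(
    model,
    peft_config=peft_config,
    train_mode=True,
    auto_find_all_linears=True,     # fills peft_config.target_modules via find_all_linear_names
)
peft_model.print_trainable_parameters()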


@ -0,0 +1,215 @@
import sys
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
class Perplexity:
"""
A class for calculating the perplexity of a language model.
"""
def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
"""
Calculate perplexity using the same method as seen in llama.cpp.
Parameters
----------
model : AutoModelForCausalLM
The language model for which the perplexity is calculated.
tokenizer : AutoTokenizer
The tokenizer corresponding to the model.
dataset_path : str, optional
The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
dataset_name : str, optional
The name of the dataset. Default is None.
split : str, optional
The split of the dataset to use. Default is 'test'.
text_column : str, optional
The name of the column in the dataset that contains the text data. Default is 'text'.
"""
self._model = model
self._tokenizer = tokenizer
self._dataset_path = dataset_path
self._dataset_name = dataset_name
self._split = split
self._text_column = text_column
self._text = self._prepare_data()
def _get_device(self):
if torch.backends.mps.is_available():
return 'mps'
elif torch.cuda.is_available():
return 'cuda:0'
else:
return 'cpu'
def _prepare_data(self):
"""
Prepares the dataset by loading and formatting.
Returns
-------
str
The formatted dataset as a single string.
"""
if self._dataset_path == 'wikitext':
self._dataset_name = 'wikitext-2-raw-v1'
# Load the dataset
data = load_dataset(self._dataset_path, self._dataset_name, split=self._split)
# Format the text column of the dataset
text_list = [' \n' if s == '' else s for s in data[self._text_column]]
return ''.join(text_list)
@staticmethod
def softmax(logits):
"""
Static method for applying the softmax function.
Parameters
----------
logits : np.ndarray
The input to the softmax function.
Returns
-------
np.ndarray
The output of the softmax function.
"""
e_x = np.exp(logits - np.max(logits))
return e_x / e_x.sum(axis=0)
def calculate_perplexity(self, n_ctx=512, n_batch=512):
"""
Calculates the perplexity of the language model.
Parameters
----------
n_ctx : int
The context size.
n_batch : int
The batch size.
Returns
-------
list
The list of perplexity scores calculated.
"""
# Tokenize the text
self._tokenizer.model_max_length = sys.maxsize
tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)
nll = 0.0 # Negative log likelihood
count = 0 # Counter for processed tokens
curr_ppl = 0
all_perplexity = []
with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress:
for i in progress:
# Process each batch of tokens
nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)
# Calculate and display the current perplexity
curr_ppl = np.exp(nll / count)
all_perplexity.append(curr_ppl)
progress.set_description(f"Perplexity: {curr_ppl:.4f}")
return all_perplexity
def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
"""
Processes each batch of tokens.
Parameters
----------
i : int
The batch index.
n_ctx : int
The context size.
n_batch : int
The batch size.
tokens : torch.Tensor
The tokenized text.
nll : float
The current negative log likelihood.
count : int
The current count of processed tokens.
Returns
-------
float
The updated negative log likelihood.
int
The updated count of processed tokens.
"""
start = i * n_ctx
end = start + n_ctx
num_batches = (n_ctx + n_batch - 1) // n_batch
logits = []
for j in range(num_batches):
batch_start = start + j * n_batch
batch_size = min(end - batch_start, n_batch)
token_org = tokens[0][batch_start].item()
if j == 0:
# Replace the first token with the BOS token
tokens[0][batch_start] = self._tokenizer.bos_token_id
# Compute the logits for the current batch of tokens
batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)
tokens[0][batch_start] = token_org
logits.append(batch_logits)
# We rely on the fact that attention in the forward pass only looks at previous
# tokens here, so the logits returned for each token are an accurate representation
# of what the model would have predicted at that point.
#
        # For example, with a context window of 512, we compute perplexity for each of
        # the last 256 tokens. We then split the input into context-window-sized chunks
        # to process the entire prompt.
for j in range(min(512, n_ctx // 2), n_ctx - 1):
tok_logits = logits[0][0][j].cpu().numpy()
# Compute the probability of the next token
prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]
# Update the negative log likelihood and the count of processed tokens
nll += -np.log(prob, where=prob>0)
count += 1
return nll, count
def _compute_batch_logits(self, tokens, batch_start, batch_size):
"""
Computes the logits for a batch of tokens.
Parameters
----------
tokens : torch.Tensor
The tokenized text.
batch_start : int
The start index of the batch.
batch_size : int
The size of the batch.
Returns
-------
torch.Tensor
The logits for the batch of tokens.
"""
# Compute the logits without keeping track of gradients
with torch.no_grad():
outputs = self._model(tokens[:, batch_start:batch_start+batch_size])
return outputs.logits.detach()
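A short usage sketch for the helper above. The model id is a placeholder and the loader call is an assumption about how a quantized model would typically be obtained; any model/tokenizer pair compatible with AutoModelForCausalLM should work the same way.

# Sketch only; the model id and loader kwargs are placeholders.
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

model_id = "some/quantized-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0")

ppl = Perplexity(model, tokenizer)                        # defaults: wikitext, wikitext-2-raw-v1, test split
scores = ppl.calculate_perplexity(n_ctx=512, n_batch=512)
print(f"perplexity over {len(scores)} chunks, final estimate: {scores[-1]:.4f}")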


@ -0,0 +1,187 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
void vecquant2matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant2matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant3matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant3matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant4matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant4matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant8matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant8matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
// old
void vecquant2matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant2matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant2matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant3matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant3matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant4matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant4matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant8matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant8matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant2matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant2matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
void vecquant3matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant3matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
void vecquant4matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant4matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant2matmul_old", &vecquant2matmul_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul_old", &vecquant3matmul_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul_old", &vecquant4matmul_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant8matmul_old", &vecquant8matmul_old, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant2matmul_faster_old", &vecquant2matmul_faster_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA), faster version");
m.def("vecquant3matmul_faster_old", &vecquant3matmul_faster_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
m.def("vecquant4matmul_faster_old", &vecquant4matmul_faster_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version");
}
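From Python, the bindings above are reached through the compiled extension; in auto_gptq they are normally invoked by the CUDA QuantLinear layers rather than by user code. The sketch below calls the 4-bit non-desc_act kernel directly; the extension module name "autogptq_cuda_256" and the packing/dtype conventions are assumptions.

# Sketch only: the module name and the tensor layouts below are assumptions.
import torch
import autogptq_cuda_256

batch, in_features, out_features, groupsize = 1, 4096, 4096, 128
device = "cuda:0"

vec = torch.randn(batch, in_features, dtype=torch.float32, device=device)                           # activations
mat = torch.zeros(in_features // 8, out_features, dtype=torch.int32, device=device)                 # packed 4-bit qweight
scales = torch.ones(in_features // groupsize, out_features, dtype=torch.float32, device=device)     # per-group scales
zeros = torch.zeros(in_features // groupsize, out_features // 8, dtype=torch.int32, device=device)  # packed qzeros
mul = torch.zeros(batch, out_features, dtype=torch.float32, device=device)                          # output buffer

autogptq_cuda_256.vecquant4matmul_old(vec, mat, mul, scales, zeros, groupsize)
print(mul.shape)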

File diff suppressed because it is too large


@ -0,0 +1,187 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
void vecquant2matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant2matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant3matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant3matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant4matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant4matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
void vecquant8matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
);
void vecquant8matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
torch::Tensor g_idx
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
}
// old
void vecquant2matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant2matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant2matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant3matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant3matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant4matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant4matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant8matmul_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
);
void vecquant8matmul_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda_old(vec, mat, mul, scales, zeros, groupsize);
}
void vecquant2matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant2matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
void vecquant3matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant3matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
void vecquant4matmul_faster_cuda_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
);
void vecquant4matmul_faster_old(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros,
int groupsize, int vec_height
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_faster_cuda_old(vec, mat, mul, scales, zeros, groupsize, vec_height);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
m.def("vecquant2matmul_old", &vecquant2matmul_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul_old", &vecquant3matmul_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul_old", &vecquant4matmul_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant8matmul_old", &vecquant8matmul_old, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant2matmul_faster_old", &vecquant2matmul_faster_old, "Vector 2-bit Quantized Matrix Multiplication (CUDA), faster version");
m.def("vecquant3matmul_faster_old", &vecquant3matmul_faster_old, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
m.def("vecquant4matmul_faster_old", &vecquant4matmul_faster_old, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version");
}

File diff suppressed because it is too large


@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif


@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}


@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif


@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "../util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
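The comment above states the invariant column_remap_cuda is built for: permuting the columns of x with x_map while the rows of w are permuted the same way leaves the product unchanged. A small NumPy check of that identity (purely illustrative; the names only mirror the CUDA code):

# seq_x @ seq_w == x @ w for any permutation x_map (illustrative check).
import numpy as np

rng = np.random.default_rng(0)
height, dim, width = 4, 8, 3
x = rng.standard_normal((height, dim))
w = rng.standard_normal((dim, width))

x_map = rng.permutation(dim)     # plays the role of cuda_x_map derived from g_idx
seq_x = x[:, x_map]              # what column_remap_kernel writes into x_new
seq_w = w[x_map, :]              # the row order make_sequential gives the quantized weight
assert np.allclose(seq_x @ seq_w, x @ w)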


@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif


@ -0,0 +1,260 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cu_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}
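The inner loops above accumulate dot products of eight activations against eight 4-bit weights unpacked from a single uint32, after subtracting the stored zero point plus one and applying the per-group scale. A NumPy sketch of that per-uint32 step (values are made up; only the arithmetic mirrors the dot_product_8 helpers):

# One dot_product_8-style step in NumPy (illustrative values).
import numpy as np

x8 = np.random.randn(8).astype(np.float32)    # eight activations from one row of x
packed = 0x76543210                           # one uint32 holding eight 4-bit weight nibbles
w_zero = 3 + 1                                # zero point read from w_zeros, plus 1 as in the kernel
w_scale = np.float32(0.05)                    # per-group scale from w_scales

q = np.array([(packed >> (4 * i)) & 0xF for i in range(8)], dtype=np.int32)
acc = np.dot(x8, (q - w_zero).astype(np.float32) * w_scale)
print(acc)                                    # what one iteration contributes to one output element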


@ -0,0 +1,43 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "../tuning.h"
// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif


@ -0,0 +1,225 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
height / 8,
1
);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}
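make_sequential above turns an act-order g_idx into x_map: rows are stably regrouped so that each quantization group becomes contiguous, which is what lets the kernels index scales and zeros by k / groupsize. A Python sketch of the same bookkeeping on a hypothetical g_idx (illustrative only):

# Python version of the histogram / group-map / x_map construction above.
import numpy as np

g_idx = np.array([2, 0, 1, 0, 2, 1, 0, 1])                # hypothetical per-row group indices
groups = int(g_idx.max()) + 1

counts = np.bincount(g_idx, minlength=groups)             # "Group histogram"
cursor = np.concatenate(([0], np.cumsum(counts)[:-1]))    # "Group map": first slot of each group

x_map_inv = np.empty_like(g_idx)
for row, g in enumerate(g_idx):                           # "X map (inverse)"
    x_map_inv[row] = cursor[g]
    cursor[g] += 1
x_map = np.empty_like(g_idx)
x_map[x_map_inv] = np.arange(len(g_idx))                  # "X map": new row -> source row

print(x_map)          # [1 3 6 2 5 7 0 4]
print(g_idx[x_map])   # rows now grouped: [0 0 0 1 1 1 2 2]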


@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif


@ -0,0 +1,255 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
}
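A usage sketch for the bindings registered above, from the Python side. The extension module name ("exllama_kernels"), the buffer sizes and the dummy tensor layouts are assumptions; in auto_gptq the exllama QuantLinear layer normally performs these calls, and the meta-device tensor stands in where an optional argument such as g_idx is absent.

# Sketch only: module name, buffer sizes and tensor layouts below are assumptions.
import torch
import exllama_kernels

device = torch.device("cuda:0")
none_tensor = torch.empty((1, 1), device="meta")     # passed where an optional tensor (g_idx) is absent

# Shared scratch buffers for this device, sized for the intended max input length / hidden size.
max_tokens, hidden = 2048, 4096
temp_state = torch.zeros((max_tokens, hidden), dtype=torch.float16, device=device)
temp_dq = torch.zeros((hidden * hidden,), dtype=torch.float16, device=device)
exllama_kernels.prepare_buffers(device, temp_state, temp_dq)
exllama_kernels.set_tuning_params(8, False, False)   # matmul_recons_thd, matmul_fused_remap, matmul_no_half2

# Wrap one layer's packed buffers into a Q4Matrix handle, then multiply.
in_features, out_features, groupsize = hidden, hidden, 128
qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32, device=device)
qzeros = torch.zeros((in_features // groupsize, out_features // 8), dtype=torch.int32, device=device)
scales = torch.ones((in_features // groupsize, out_features), dtype=torch.float16, device=device)
q4 = exllama_kernels.make_q4(qweight, qzeros, scales, none_tensor, device.index)

x = torch.randn((4, in_features), dtype=torch.float16, device=device)
out = torch.empty((x.shape[0], out_features), dtype=torch.float16, device=device)
exllama_kernels.q4_matmul(x, q4, out)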


@ -0,0 +1,49 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif


@ -0,0 +1,294 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif


@@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning
{
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif


@@ -0,0 +1,33 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
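// Both macros below expect the calling function to declare a local cudaError_t _cuda_err
// and to provide a _cuda_fail label for error handling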
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif


@@ -0,0 +1,13 @@
#ifndef _config_h
#define _config_h
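// Row count above which gemm_half_q_half_cuda reconstructs the full FP16 weight matrix
// and calls cuBLAS instead of the fused quantized kernels (see q_gemm.cu)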
#define MAX_Q_GEMM_ROWS 50
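// QMODE_nBIT == 1 selects the optimized, pre-shuffled dequant kernels in quant/qdq_n.cuh;
// 0 selects the generic (unshuffled) fallback implementations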
#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0
#endif


@@ -0,0 +1,12 @@
#ifndef _util_h
#define _util_h
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#endif


@@ -0,0 +1,56 @@
#ifndef _compat_cuh
#define _compat_cuh
// atomicAdd for half types, to support CC < 7.x
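// The half is updated through the aligned 32-bit word that contains it: the target 16-bit
// lane is extracted, added to val, spliced back in, and committed with atomicCAS, retrying
// until no other thread has modified the word in the meantime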
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
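// Hardware atomicAdd on half is available from CC 7.0 and on half2 from CC 6.0, so the
// software fallbacks above are only substituted below those thresholds (and on ROCm)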
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif


@@ -0,0 +1,121 @@
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "quant/qdq_util.cuh"
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
#endif


@@ -0,0 +1,238 @@
#include "q_gemm.cuh"
#include "util.cuh"
#include "matrix_view.cuh"
#include "../config.h"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"
#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
void gemm_half_q_half_cuda_part
(
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
bool clear
)
{
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
kernel<<<gridDim, blockDim>>>
(
a,
b->cuda_q_weight,
b->cuda_q_scale,
b->cuda_q_scale_max,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_8,
b->rows_6,
b->rows_5,
b->rows_4,
b->rows_3,
b->rows_2,
clear
);
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
// DBGX((uint64_t) b->cuda_q_perm);
// DBGI(b->rows_4);
// DBGI(b->height);
kernel<<<gridDim, blockDim>>>
(
a,
b->cuda_q_weight,
b->cuda_gptq_qzeros,
b->cuda_gptq_scales,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_4,
clear
);
}
}
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear,
half* temp_dq,
bool force_cuda
)
{
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq;
b->reconstruct(temp_dq);
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
const half alpha = __float2half(1.0f);
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
cublasHgemm(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
size_n, size_m, size_k,
&alpha, temp_dq, size_n,
a, size_k,
&beta, c, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//cublasGemmEx(cublas_handle,
// CUBLAS_OP_N, CUBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n,
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
}
else
{
//printf("cuda\n");
// Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n);
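// Process the bulk of the rows in full BLOCK_M_SIZE_MAX-row chunks, then launch a kernel
// instantiated for the exact remainder row count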
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk;
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
}
}
}
__global__ void clear_kernel
(
half* __restrict__ c,
const int size_m,
const int size_n
)
{
int m = blockIdx.y;
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
if (n >= size_n) return;
int4* c_ptr = (int4*)(c + m * size_n + n);
*c_ptr = {};
}
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
)
{
return;
dim3 blockDim, gridDim;
blockDim.x = CLEAR_N_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.y = size_m;
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
}


@@ -0,0 +1,33 @@
#ifndef _q_gemm_cuh
#define _q_gemm_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q_matrix.cuh"
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
);
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
);
#endif


@@ -0,0 +1,484 @@
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const int,
const int,
const int,
const int,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
block_a_ptr[t] = a0;
}
}
// Clear
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
// Preload scales
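// The 4-bit b_q_scale entries encode each group scale as q, with the effective scale
// reconstructed here as (q + 1)^2 * b_q_scale_max[group]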
float scales[MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
qscales[0]++;
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
}
// a, b offset
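// Convert the k offset into a packed-word offset: every 32 k-rows of b-bit weights occupy
// b uint32 rows of b_q_weight, accumulated over the bit-width segments preceding offset_k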
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
// Column result
float block_c[m_count][4] = {};
// Dequantize groups
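// rows_8 .. rows_2 are cumulative row boundaries (rows are grouped by descending bit-width),
// so each loop below handles one bit-width segment and hands over at its boundary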
int k = offset_k;
while (k < rows_8 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[5];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
// Accumulate column sums in c
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL;
}


@@ -0,0 +1,219 @@
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
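// Zero-points are passed to dequant_4bit_8_prep_zero with + 1 applied (matching the packed
// GPTQ qzeros convention); the helper precomputes per-column half2 constants consumed by
// dequant_4bit_8_gptq in the main loop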
int zeros[4];
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
// __syncthreads();
// Column result
float block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}


@@ -0,0 +1,603 @@
#include "q_matrix.cuh"
#include "matrix_view.cuh"
#include "util.cuh"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
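// Each bit-width packs its values into whole uint32 words (e.g. sixteen 6-bit values per
// three words), so the loops below advance 1/3/5/1/3/1 word-rows per 4/16/32/8/32/16 k-rows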
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
// QMatrix constructor
QMatrix::QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
) :
device(_device),
height(_height),
width(_width),
groups(_groups),
temp_dq(_temp_dq)
{
cudaSetDevice(device);
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
// Create group map
rows_8 = 0;
rows_6 = 0;
rows_5 = 0;
rows_4 = 0;
rows_3 = 0;
rows_2 = 0;
if (!is_gptq)
{
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
}
free(cpu_q_groups);
rows_6 += rows_8;
rows_5 += rows_6;
rows_4 += rows_5;
rows_3 += rows_4;
rows_2 += rows_3;
}
else
{
rows_4 = height;
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
}
// Shuffle quantized data
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
// Reconstruct b[k,n] (GPTQ)
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_4
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ uint16_t perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
// Reconstruct b[k,n]
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
// Preload remapping table
int t = threadIdx.x;
__shared__ uint16_t perm[BLOCK_KN_SIZE];
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
// Column
int n = offset_n + t;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
int lk = 0;
__syncthreads();
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
dequant_8bit_8(q_0, q_1, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
uint32_t q_3 = *b_ptr; b_ptr += size_n;
uint32_t q_4 = *b_ptr; b_ptr += size_n;
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_4bit_8(q_0, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_2bit_16(q_0, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
}
void QMatrix::reconstruct(half* out)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
reconstruct_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
height,
width,
groupsize,
groups,
out,
rows_8,
rows_6,
rows_5,
rows_4,
rows_3,
rows_2
);
}
else
{
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_gptq_qzeros,
cuda_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
groups,
out,
rows_4
);
}
}
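// Re-pack the quantized weights into activation order: each thread gathers eight source
// rows (via q_perm) into one destination 8-row group, working on 64-bit words so two
// adjacent columns are handled per iteration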
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Reduce to uint16_t
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
// Move to CUDA
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_new_qweight,
cuda_q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}


@@ -0,0 +1,71 @@
#ifndef _q_matrix_cuh
#define _q_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define MAX_SUPERGROUPS 16
class QMatrix
{
public:
int device;
bool is_gptq;
int height;
int width;
int groups;
int groupsize;
int rows_8;
int rows_6;
int rows_5;
int rows_4;
int rows_3;
int rows_2;
uint32_t* cuda_q_weight = NULL;
uint16_t* cuda_q_perm = NULL;
uint16_t* cuda_q_invperm = NULL;
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
half* temp_dq;
QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
);
~QMatrix();
void reconstruct(half* out);
void make_sequential(const uint32_t* cpu_g_idx);
private:
};
#endif


@@ -0,0 +1,103 @@
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_2BIT == 1
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
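// Bit trick: OR-ing each pair of 2-bit fields into the half constant 1024.0 (0x6400) puts
// the quantized value in the mantissa, so one __hadd2/__hfma2 with the matching y*/z*
// constants below rescales and removes both the 1024 bias and the fixed zero-point of 2,
// with no integer-to-float conversion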
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z4 = __halves2half2(z4_, z4_);
const half2 z16 = __halves2half2(z16_, z16_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
#else
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif


@@ -0,0 +1,169 @@
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_3BIT == 1
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
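// The leftover high bits of qa, qb, qc each carry one bit of q[30]/q[31];
// OR-ing the three masked remainders with c0 rebuilds half2(q[30], q[31]) + 1024 in q15.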
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
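In the fallback dequant above, q[10] and q[21] straddle a word boundary, and the two-word exb overload in qdq_util.cuh recovers them with a funnel shift. A host-side stand-in for that extraction, assuming a plain 64-bit shift in place of __funnelshift_rc:
#include <cstdint>
#include <cstdio>
// Same contract as exb(q1, q0, shift, mask): read a mask-wide field starting at bit
// `shift` of the 64-bit concatenation q1:q0.
static int exb2(uint32_t q1, uint32_t q0, int shift, int mask)
{
    return (int)(((((uint64_t)q1 << 32) | q0) >> shift) & (uint32_t)mask);
}
int main()
{
    // Place the 3-bit value 5 at bit offset 30: its low two bits land in q0,
    // its high bit in q1, exactly like q[10] in the packed stream.
    const uint64_t packed = (uint64_t)5 << 30;
    const uint32_t q0 = (uint32_t)packed;
    const uint32_t q1 = (uint32_t)(packed >> 32);
    printf("straddling field = %d (expected 5)\n", exb2(q1, q0, 30, 0x07));
    return 0;
}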

@@ -0,0 +1,227 @@
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
#else
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
#endif
#endif
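dequant_4bit_8_prep_zero_scale folds the zero point and scale into two constant pairs so the GPTQ path is one __hfma2 per half2: the raw lanes decode to 1024 + q and 1024 + 16q, and the fused constants collapse both back to (q - zero) * scale. A float sketch of that identity, with an arbitrary zero and scale chosen here purely for illustration:
#include <cstdio>
int main()
{
    const float scale = 0.0625f; // example values, not taken from the diff
    const int zero = 7;
    for (int q = 0; q < 16; q++)
    {
        const float raw_lo = 1024.0f + q;          // low nibble OR'ed into half(1024)
        const float raw_hi = 1024.0f + 16.0f * q;  // next nibble, still carrying the *16
        const float z1  = -(1024.0f + zero) * scale;   // half_uint16(0xe400 | zero) times scale
        const float z16 = (-64.0f - zero) * scale;     // __hsub(-64, zero) times scale
        const float y1  = 1.0f * scale;
        const float y16 = (1.0f / 16.0f) * scale;
        const float dq_lo = raw_lo * y1 + z1;   // __hfma2 with y1y16[0], z1z16[0]
        const float dq_hi = raw_hi * y16 + z16; // __hfma2 with y1y16[1], z1z16[1]
        printf("q=%2d  %9.4f  %9.4f  ref %9.4f\n", q, dq_lo, dq_hi, (q - zero) * scale);
    }
    return 0;
}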

@@ -0,0 +1,207 @@
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
uint32_t qd = q[3 * stride];
uint32_t qe = q[4 * stride];
// qa: 66555554 44443333 32222211 11100000
// qb: ccccbbbb baaaaa99 99988888 77777666
// qc: jiiiiihh hhhggggg fffffeee eedddddc
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
uint32_t qf = qe >> 22;
qe <<= 8;
qe |= qd >> 24;
qd <<= 6;
qd |= qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: 555554 44443333 32222211 11100000
// qb: bbbbba aaaa9999 98888877 77766666
// qc: hhhhhg ggggffff feeeeedd dddccccc
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
// qf: vv vvvuuuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
uint32_t zd = 0;
uint32_t ze = 0;
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
// za: 5555533 33311111 4444422 22200000
// zb: bbbbb99 99977777 aaaaa88 88866666
// zc: hhhhhff fffddddd gggggee eeeccccc
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
// ze: tttttrr rrrppppp sssssqq qqqooooo
// qf: vv vvvuuuuu
za |= ((qf & 0x001) >> 0) << 15;
zb |= ((qf & 0x002) >> 1) << 15;
zc |= ((qf & 0x004) >> 2) << 15;
zd |= ((qf & 0x008) >> 3) << 15;
ze |= ((qf & 0x010) >> 4) << 15;
za |= ((qf & 0x020) >> 5) << 31;
zb |= ((qf & 0x040) >> 6) << 31;
zc |= ((qf & 0x080) >> 7) << 31;
zd |= ((qf & 0x100) >> 8) << 31;
ze |= ((qf & 0x200) >> 9) << 31;
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
// zb: vbbbbb99 99977777 uaaaaa88 88866666
// zc: vhhhhhff fffddddd ugggggee eeeccccc
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
// ze: vtttttrr rrrppppp usssssqq qqqooooo
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
q[3 * stride] = zd;
q[4 * stride] = ze;
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y32_ = __float2half_rn(1.0f / 32.0f);
const half2 y32 = __halves2half2(y32_, y32_);
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z32 = __halves2half2(z32_, z32_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
uint32_t qd = q_3;
uint32_t qe = q_4;
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
qa >>= 10;
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
qa >>= 5;
qa &= 0x00010001;
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
qb >>= 10;
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
qb >>= 4;
qb &= 0x00020002;
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
qc >>= 10;
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
qc >>= 3;
qc &= 0x00040004;
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
qd >>= 10;
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
qd >>= 2;
qd &= 0x00080008;
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
qe >>= 10;
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
qe >>= 1;
qe &= 0x00100010;
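// Each packed word's top bit carries one bit of q[30]/q[31]; the five masked
// leftovers OR together with c0 to form half2(q[30], q[31]) + 1024 in q15.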
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hadd2( q3.as_half2, z1);
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hadd2( q6.as_half2, z1);
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
dq[ 8] = __hadd2( q8.as_half2, z1);
dq[ 9] = __hadd2( q9.as_half2, z1);
dq[10] = __hfma2(q10.as_half2, y32, z32);
dq[11] = __hadd2(q11.as_half2, z1);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y32, z32);
dq[14] = __hadd2(q14.as_half2, z1);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

@@ -0,0 +1,44 @@
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

@@ -0,0 +1,38 @@
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif

@@ -0,0 +1,51 @@
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif

@@ -0,0 +1,32 @@
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
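dq_scale in qdq_util.cuh expects max_scale already premultiplied by 1/256, while dq_scale_ above folds the same factor in as two 1/16 multiplications; both evaluate to ((qs + 1) / 16)^2 * max_scale. A quick float check of that equivalence, with an example max_scale:
#include <cstdio>
int main()
{
    const float max_scale = 0.5f; // example value
    for (int qs = 0; qs < 16; qs++)
    {
        const float a = (qs + 1) * (qs + 1) * (max_scale / 256.0f);          // dq_scale, premultiplied input
        const float b = ((qs + 1) / 16.0f) * ((qs + 1) / 16.0f) * max_scale; // dq_scale_
        printf("qs=%2d  %.6f  %.6f\n", qs, a, b);
    }
    return 0;
}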

@@ -0,0 +1,134 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"
// Some decluttering macros
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
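// A tensor passed on the meta device stands in for "not provided": the *_OPT checks
// skip it, and make_q_matrix passes NULL to QMatrix for its data pointer below.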
// Quant matrix
uintptr_t make_q_matrix
(
torch::Tensor q_weight,
torch::Tensor q_perm,
torch::Tensor q_invperm,
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
torch::Tensor temp_dq
)
{
TORCH_CHECK_DTYPE(q_weight, kInt);
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
int device = q_weight.device().index();
int width = q_weight.size(1);
int groups;
int height;
if (!q_scale.device().is_meta())
{
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
groups = q_scale.size(0);
height = q_invperm.size(0);
}
else
{
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
groups = gptq_qzeros.size(0);
height = q_weight.size(0) * 8;
}
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
QMatrix* m = new QMatrix
(
device,
height,
width,
groups,
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
return reinterpret_cast<uintptr_t> (m);
}
void gemm_half_q_half
(
torch::Tensor a,
uintptr_t b,
torch::Tensor c,
bool force_cuda
)
{
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
TORCH_CHECK_DTYPE(a, kHalf);
TORCH_CHECK_DTYPE(c, kHalf);
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
qm,
(half*) c.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
true,
NULL,
force_cuda
);
}
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}

File diff suppressed because it is too large

@@ -0,0 +1,480 @@
#include<omp.h>
#include<immintrin.h>
#include<fstream>
#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
inline
void q2gemm_gs(const float* __restrict__ input,
const int* __restrict__ W,
const float* __restrict__ scales,
const float* __restrict__ zeros,
const float* __restrict__ bias,
const float* __restrict__ sums,
float* __restrict__ output,
const int n,
const int m,
const int t,
const int nb,
const int mb,
const int tb,
int ogtt,
const int gs,
const int cutoff){
#pragma omp parallel num_threads(8)
{
int tid;
const int mu = 16;
const int nu = 1;
const int tu = 32;
const int on = n / nb;
const int om = m / mb;
const __m256i mask = _mm256_set1_epi32(3);
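// Each 32-bit word of W packs sixteen 2-bit weights; the inner k2 loop peels
// them off with shifts 0, 2, ..., 30 and this 2-bit mask.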
tid = omp_get_thread_num();
int tt = ogtt;
if(tid >= cutoff){
tt -= tb;
}
const int base_output = tid >= cutoff ?
(tid-cutoff)*tt + (tt+tb)*cutoff:
tid*tt;
const int base_W = tid >= cutoff ?
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16:
tid*tt*m/16;
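// Threads with tid >= cutoff handle tb fewer output columns than the first `cutoff`
// threads; base_output and base_W point each thread at its slice of the packed
// output and weight blocks.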
for(int j = 0; j < tt; j+=tb){
for(int i = 0; i < on; i++) {
for(int k = 0; k < om; k++) {
for(int i1 = 0; i1 < nb; i1+=nu) {
int j1 = 0;
for(; j1 < tb-tu+1; j1+=tu) {
for(int k1 = 0; k1 < mb; k1+=gs) {
__m256 acc0_0 = _mm256_setzero_ps();
__m256 acc0_8 = _mm256_setzero_ps();
__m256 acc0_16 = _mm256_setzero_ps();
__m256 acc0_24 = _mm256_setzero_ps();
for(int k2 = k1; k2 < k1+gs; k2+=16)
{
__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]);
__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]);
__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]);
__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]);
__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]);
__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]);
__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]);
__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]);
__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]);
__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]);
__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]);
__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]);
__m256i ws0_8 = _mm256_srli_epi32(w0, 16);
__m256i ws8_8 = _mm256_srli_epi32(w8, 16);
__m256i ws16_8 = _mm256_srli_epi32(w16, 16);
__m256i ws24_8 = _mm256_srli_epi32(w24, 16);
__m256i wsa0_8= _mm256_and_si256(ws0_8, mask);
__m256i wsa8_8= _mm256_and_si256(ws8_8, mask);
__m256i wsa16_8= _mm256_and_si256(ws16_8, mask);
__m256i wsa24_8= _mm256_and_si256(ws24_8, mask);
__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8);
__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8);
__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8);
__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8);
acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24);
__m256i ws0_9 = _mm256_srli_epi32(w0, 18);
__m256i ws8_9 = _mm256_srli_epi32(w8, 18);
__m256i ws16_9 = _mm256_srli_epi32(w16, 18);
__m256i ws24_9 = _mm256_srli_epi32(w24, 18);
__m256i wsa0_9= _mm256_and_si256(ws0_9, mask);
__m256i wsa8_9= _mm256_and_si256(ws8_9, mask);
__m256i wsa16_9= _mm256_and_si256(ws16_9, mask);
__m256i wsa24_9= _mm256_and_si256(ws24_9, mask);
__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9);
__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9);
__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9);
__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9);
acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24);
__m256i ws0_10 = _mm256_srli_epi32(w0, 20);
__m256i ws8_10 = _mm256_srli_epi32(w8, 20);
__m256i ws16_10 = _mm256_srli_epi32(w16, 20);
__m256i ws24_10 = _mm256_srli_epi32(w24, 20);
__m256i wsa0_10= _mm256_and_si256(ws0_10, mask);
__m256i wsa8_10= _mm256_and_si256(ws8_10, mask);
__m256i wsa16_10= _mm256_and_si256(ws16_10, mask);
__m256i wsa24_10= _mm256_and_si256(ws24_10, mask);
__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10);
__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10);
__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10);
__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10);
acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24);
__m256i ws0_11 = _mm256_srli_epi32(w0, 22);
__m256i ws8_11 = _mm256_srli_epi32(w8, 22);
__m256i ws16_11 = _mm256_srli_epi32(w16, 22);
__m256i ws24_11 = _mm256_srli_epi32(w24, 22);
__m256i wsa0_11= _mm256_and_si256(ws0_11, mask);
__m256i wsa8_11= _mm256_and_si256(ws8_11, mask);
__m256i wsa16_11= _mm256_and_si256(ws16_11, mask);
__m256i wsa24_11= _mm256_and_si256(ws24_11, mask);
__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11);
__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11);
__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11);
__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11);
acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24);
__m256i ws0_12 = _mm256_srli_epi32(w0, 24);
__m256i ws8_12 = _mm256_srli_epi32(w8, 24);
__m256i ws16_12 = _mm256_srli_epi32(w16, 24);
__m256i ws24_12 = _mm256_srli_epi32(w24, 24);
__m256i wsa0_12= _mm256_and_si256(ws0_12, mask);
__m256i wsa8_12= _mm256_and_si256(ws8_12, mask);
__m256i wsa16_12= _mm256_and_si256(ws16_12, mask);
__m256i wsa24_12= _mm256_and_si256(ws24_12, mask);
__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12);
__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12);
__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12);
__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12);
acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24);
__m256i ws0_13 = _mm256_srli_epi32(w0, 26);
__m256i ws8_13 = _mm256_srli_epi32(w8, 26);
__m256i ws16_13 = _mm256_srli_epi32(w16, 26);
__m256i ws24_13 = _mm256_srli_epi32(w24, 26);
__m256i wsa0_13= _mm256_and_si256(ws0_13, mask);
__m256i wsa8_13= _mm256_and_si256(ws8_13, mask);
__m256i wsa16_13= _mm256_and_si256(ws16_13, mask);
__m256i wsa24_13= _mm256_and_si256(ws24_13, mask);
__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13);
__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13);
__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13);
__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13);
acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24);
__m256i ws0_14 = _mm256_srli_epi32(w0, 28);
__m256i ws8_14 = _mm256_srli_epi32(w8, 28);
__m256i ws16_14 = _mm256_srli_epi32(w16, 28);
__m256i ws24_14 = _mm256_srli_epi32(w24, 28);
__m256i wsa0_14= _mm256_and_si256(ws0_14, mask);
__m256i wsa8_14= _mm256_and_si256(ws8_14, mask);
__m256i wsa16_14= _mm256_and_si256(ws16_14, mask);
__m256i wsa24_14= _mm256_and_si256(ws24_14, mask);
__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14);
__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14);
__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14);
__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14);
acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24);
__m256i ws0_15 = _mm256_srli_epi32(w0, 30);
__m256i ws8_15 = _mm256_srli_epi32(w8, 30);
__m256i ws16_15 = _mm256_srli_epi32(w16, 30);
__m256i ws24_15 = _mm256_srli_epi32(w24, 30);
__m256i wsa0_15= _mm256_and_si256(ws0_15, mask);
__m256i wsa8_15= _mm256_and_si256(ws8_15, mask);
__m256i wsa16_15= _mm256_and_si256(ws16_15, mask);
__m256i wsa24_15= _mm256_and_si256(ws24_15, mask);
__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15);
__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15);
__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15);
__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15);
acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24);
__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]);
__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]);
__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]);
__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]);
__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]);
__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]);
__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]);
__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]);
__m256i ws0_0 = _mm256_srli_epi32(w0, 0);
__m256i ws8_0 = _mm256_srli_epi32(w8, 0);
__m256i ws16_0 = _mm256_srli_epi32(w16, 0);
__m256i ws24_0 = _mm256_srli_epi32(w24, 0);
__m256i wsa0_0= _mm256_and_si256(ws0_0, mask);
__m256i wsa8_0= _mm256_and_si256(ws8_0, mask);
__m256i wsa16_0= _mm256_and_si256(ws16_0, mask);
__m256i wsa24_0= _mm256_and_si256(ws24_0, mask);
__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0);
__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0);
__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0);
__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0);
acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24);
__m256i ws0_1 = _mm256_srli_epi32(w0, 2);
__m256i ws8_1 = _mm256_srli_epi32(w8, 2);
__m256i ws16_1 = _mm256_srli_epi32(w16, 2);
__m256i ws24_1 = _mm256_srli_epi32(w24, 2);
__m256i wsa0_1= _mm256_and_si256(ws0_1, mask);
__m256i wsa8_1= _mm256_and_si256(ws8_1, mask);
__m256i wsa16_1= _mm256_and_si256(ws16_1, mask);
__m256i wsa24_1= _mm256_and_si256(ws24_1, mask);
__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1);
__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1);
__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1);
__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1);
acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24);
__m256i ws0_2 = _mm256_srli_epi32(w0, 4);
__m256i ws8_2 = _mm256_srli_epi32(w8, 4);
__m256i ws16_2 = _mm256_srli_epi32(w16, 4);
__m256i ws24_2 = _mm256_srli_epi32(w24, 4);
__m256i wsa0_2= _mm256_and_si256(ws0_2, mask);
__m256i wsa8_2= _mm256_and_si256(ws8_2, mask);
__m256i wsa16_2= _mm256_and_si256(ws16_2, mask);
__m256i wsa24_2= _mm256_and_si256(ws24_2, mask);
__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2);
__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2);
__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2);
__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2);
acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24);
__m256i ws0_3 = _mm256_srli_epi32(w0, 6);
__m256i ws8_3 = _mm256_srli_epi32(w8, 6);
__m256i ws16_3 = _mm256_srli_epi32(w16, 6);
__m256i ws24_3 = _mm256_srli_epi32(w24, 6);
__m256i wsa0_3= _mm256_and_si256(ws0_3, mask);
__m256i wsa8_3= _mm256_and_si256(ws8_3, mask);
__m256i wsa16_3= _mm256_and_si256(ws16_3, mask);
__m256i wsa24_3= _mm256_and_si256(ws24_3, mask);
__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3);
__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3);
__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3);
__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3);
acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24);
__m256i ws0_4 = _mm256_srli_epi32(w0, 8);
__m256i ws8_4 = _mm256_srli_epi32(w8, 8);
__m256i ws16_4 = _mm256_srli_epi32(w16, 8);
__m256i ws24_4 = _mm256_srli_epi32(w24, 8);
__m256i wsa0_4= _mm256_and_si256(ws0_4, mask);
__m256i wsa8_4= _mm256_and_si256(ws8_4, mask);
__m256i wsa16_4= _mm256_and_si256(ws16_4, mask);
__m256i wsa24_4= _mm256_and_si256(ws24_4, mask);
__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4);
__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4);
__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4);
__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4);
acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24);
__m256i ws0_5 = _mm256_srli_epi32(w0, 10);
__m256i ws8_5 = _mm256_srli_epi32(w8, 10);
__m256i ws16_5 = _mm256_srli_epi32(w16, 10);
__m256i ws24_5 = _mm256_srli_epi32(w24, 10);
__m256i wsa0_5= _mm256_and_si256(ws0_5, mask);
__m256i wsa8_5= _mm256_and_si256(ws8_5, mask);
__m256i wsa16_5= _mm256_and_si256(ws16_5, mask);
__m256i wsa24_5= _mm256_and_si256(ws24_5, mask);
__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5);
__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5);
__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5);
__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5);
acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24);
__m256i ws0_6 = _mm256_srli_epi32(w0, 12);
__m256i ws8_6 = _mm256_srli_epi32(w8, 12);
__m256i ws16_6 = _mm256_srli_epi32(w16, 12);
__m256i ws24_6 = _mm256_srli_epi32(w24, 12);
__m256i wsa0_6= _mm256_and_si256(ws0_6, mask);
__m256i wsa8_6= _mm256_and_si256(ws8_6, mask);
__m256i wsa16_6= _mm256_and_si256(ws16_6, mask);
__m256i wsa24_6= _mm256_and_si256(ws24_6, mask);
__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6);
__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6);
__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6);
__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6);
acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24);
__m256i ws0_7 = _mm256_srli_epi32(w0, 14);
__m256i ws8_7 = _mm256_srli_epi32(w8, 14);
__m256i ws16_7 = _mm256_srli_epi32(w16, 14);
__m256i ws24_7 = _mm256_srli_epi32(w24, 14);
__m256i wsa0_7= _mm256_and_si256(ws0_7, mask);
__m256i wsa8_7= _mm256_and_si256(ws8_7, mask);
__m256i wsa16_7= _mm256_and_si256(ws16_7, mask);
__m256i wsa24_7= _mm256_and_si256(ws24_7, mask);
__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7);
__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7);
__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7);
__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7);
acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24);
}
__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]);
__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]);
__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]);
__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]);
__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]);
__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]);
__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]);
__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]);
__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0);
__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8);
__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16);
__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24);
}
}
}
}
}
}
#pragma omp barrier
const int ngs = m/gs;
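// Epilogue: the reference dequantization treats a weight as (w_q + zero) * scale,
// so add back zeros * scales * (per-group sums of the inputs), then the bias.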
for (int i = 0; i < n; i++) {
for (int j = 0; j < tt; j+=32){
__m256 acc0 = _mm256_setzero_ps();
__m256 acc8 = _mm256_setzero_ps();
__m256 acc16 = _mm256_setzero_ps();
__m256 acc24 = _mm256_setzero_ps();
for (int i1 = 0; i1 < ngs; i1++){
__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]);
__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]);
__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]);
__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]);
__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]);
__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]);
__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]);
__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]);
__m256 zs0 = _mm256_mul_ps(z0, s0);
__m256 zs8 = _mm256_mul_ps(z8, s8);
__m256 zs16 = _mm256_mul_ps(z16, s16);
__m256 zs24 = _mm256_mul_ps(z24, s24);
acc0 = _mm256_fmadd_ps(zs0, r, acc0);
acc8 = _mm256_fmadd_ps(zs8, r, acc8);
acc16 = _mm256_fmadd_ps(zs16, r, acc16);
acc24 = _mm256_fmadd_ps(zs24, r, acc24);
}
__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]);
__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]);
__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]);
__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]);
__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]);
__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]);
__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]);
__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]);
__m256 o10 = _mm256_add_ps(o0, acc0);
__m256 o18 = _mm256_add_ps(o8, acc8);
__m256 o116 = _mm256_add_ps(o16, acc16);
__m256 o124 = _mm256_add_ps(o24, acc24);
__m256 o20 = _mm256_add_ps(o10, b0);
__m256 o28 = _mm256_add_ps(o18, b8);
__m256 o216 = _mm256_add_ps(o116, b16);
__m256 o224 = _mm256_add_ps(o124, b24);
_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20);
_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28);
_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216);
_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224);
}
}
}
}
inline void qforward(const float* __restrict__ input,
const int* __restrict__ W,
const float* __restrict__ scales,
const float* __restrict__ zeros,
const float* __restrict__ bias,
const float* __restrict__ sums,
float* __restrict__ output,
int n,
int m,
int t) {
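// Generated entry point: tile parameters are baked in as nb=1, mb=1024, tb=32,
// ogtt=512, gs=64, cutoff=9 for this configuration.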
q2gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, 1, 1024, 32, 512, 64, 9);
}
inline void pack_input(float* A, float* B){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 1;
const int M = 4096;
const int nb = 1;
const int mb = 1024;
for(int i = 0; i < N; i+=nb){
for(int j = 0; j < M; j+=mb){
for(int jj = j; jj < mymin(j+mb, M); jj++){
for(int ii = i; ii < mymin(i+nb, N); ii++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
inline void pack_qw_inner(int* A, int* B, int cutoff){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 256;
const int M = 4096;
const int nb = 64;
int mb = 32;
for(int j = 0, tid = 0; j < M; j+=mb, tid++){
for(int i = 0; i < N; i+=nb){
for(int ii = i; ii < mymin(i+nb, N); ii++){
for(int jj = j; jj < mymin(j+mb, M); jj++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
inline void pack_qw(int* A, int* B){
pack_qw_inner(A, B, 65);
}
inline void pack_output(float* A, float* B){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 1;
const int M = 4096;
const int nb = 1;
const int mb = 32;
for(int i = 0; i < N; i+=nb){
for(int j = 0; j < M; j+=mb){
for(int ii = i; ii < mymin(i+nb, N); ii++){
for(int jj = j; jj < mymin(j+mb, M); jj++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
void print_parameters(){
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
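// Column order matches the bits,nu,mu,tu,unroll,p,gs,time header used in tmp.csv;
// main() appends the measured time afterwards.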
outfile << 2 << "," << 1 << "," << 16 << "," << 32 << "," << 8 << "," << 8 << "," << 64 << ",";
}

File diff suppressed because it is too large

@@ -0,0 +1,149 @@
def load_int(to, address, const=True):
if const:
return f"const __m256i {to} = _mm256_loadu_si256({address});"
else:
return f"__m256i {to} = _mm256_loadu_si256({address});"
def load_fp(to, address, const=True):
if const:
return f"const __m256 {to} = _mm256_loadu_ps({address});"
else:
return f"__m256 {to} = _mm256_loadu_ps({address});"
# to = a * b + c
def vfma(to, a, b, c):
return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"
def vsrli(to, a, b):
return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"
def vand(to, a, b):
return f"const __m256i {to} = _mm256_and_si256({a}, {b});"
def vbroadcast_fp(to, a):
return f"const __m256 {to} = _mm256_set1_ps({a});"
def vbroadcast_int32(to, a):
return f"__m256i {to} = _mm256_set1_epi32({a});"
def vsetzero(to):
return f"__m256 {to} = _mm256_setzero_ps();"
def vcvtepi32_ps(to, a):
return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"
def _256extractf128_ps(to, a, imm):
return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"
def _256castps256_ps128(to, a):
return f"const __m128 {to} = _mm256_castps256_ps128({a});"
def _add_ps(to, a, b):
return f"const __m128 {to} = _mm_add_ps({a}, {b});"
def _movehl_ps(to, a, b):
return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"
def _shuffle_ps(to, a, b, imm):
return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"
def _cvtss_f32(to, a):
return f"const float {to} = _mm_cvtss_f32({a});"
def _reduce8_acc(a, b, c, d, e, f, g, h):
res = ""
res += _256extractf128_ps("hi_quad0", a, 1)
res += _256extractf128_ps("hi_quad1", b, 1)
res += _256extractf128_ps("hi_quad2", c, 1)
res += _256extractf128_ps("hi_quad3", d, 1)
res += _256extractf128_ps("hi_quad4", e, 1)
res += _256extractf128_ps("hi_quad5", f, 1)
res += _256extractf128_ps("hi_quad6", g, 1)
res += _256extractf128_ps("hi_quad7", h, 1)
res += _256castps256_ps128("lo_quad0", a)
res += _256castps256_ps128("lo_quad1", b)
res += _256castps256_ps128("lo_quad2", c)
res += _256castps256_ps128("lo_quad3", d)
res += _256castps256_ps128("lo_quad4", e)
res += _256castps256_ps128("lo_quad5", f)
res += _256castps256_ps128("lo_quad6", g)
res += _256castps256_ps128("lo_quad7", h)
res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")
res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")
res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")
res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)
res += _add_ps("sum0", "sum_dual0", "hi0")
res += _add_ps("sum1", "sum_dual1", "hi1")
res += _add_ps("sum2", "sum_dual2", "hi2")
res += _add_ps("sum3", "sum_dual3", "hi3")
res += _add_ps("sum4", "sum_dual4", "hi4")
res += _add_ps("sum5", "sum_dual5", "hi5")
res += _add_ps("sum6", "sum_dual6", "hi6")
res += _add_ps("sum7", "sum_dual7", "hi7")
res += _cvtss_f32(f"f{a}", "sum0")
res += _cvtss_f32(f"f{b}", "sum1")
res += _cvtss_f32(f"f{c}", "sum2")
res += _cvtss_f32(f"f{d}", "sum3")
res += _cvtss_f32(f"f{e}", "sum4")
res += _cvtss_f32(f"f{f}", "sum5")
res += _cvtss_f32(f"f{g}", "sum6")
res += _cvtss_f32(f"f{h}", "sum7")
return res
acc_idx = 0
def _reduce_add(a):
global acc_idx
res = ""
res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
acc_idx += 1
return res
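For reference, the horizontal reduction that _reduce_add(a) emits, written out as one C++ helper (the generator suffixes every temporary with acc_idx so repeated expansions don't collide):
#include <immintrin.h>
#include <cstdio>
// Sum the eight lanes of an __m256 accumulator into one float, mirroring the
// extract / movehl / shuffle sequence generated above.
static inline float reduce_add_m256(__m256 a)
{
    const __m128 hi_quad  = _mm256_extractf128_ps(a, 1);
    const __m128 lo_quad  = _mm256_castps256_ps128(a);
    const __m128 sum_quad = _mm_add_ps(lo_quad, hi_quad);
    const __m128 hi_dual  = _mm_movehl_ps(sum_quad, sum_quad);
    const __m128 sum_dual = _mm_add_ps(sum_quad, hi_dual);
    const __m128 hi       = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
    const __m128 sum      = _mm_add_ps(sum_dual, hi);
    return _mm_cvtss_f32(sum);
}
int main()
{
    const __m256 v = _mm256_set_ps(8, 7, 6, 5, 4, 3, 2, 1);
    printf("%f\n", reduce_add_m256(v)); // prints 36.000000
    return 0;
}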

BIN autogptq_extension/qigen/mmm Executable file

Binary file not shown.

@@ -0,0 +1,302 @@
#include <iostream>
#include "forward.h"
#include <cstring>
#include <algorithm>
#include <vector>
#include <chrono>
#include <fstream>
#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
void print_matrix(std::string name, float* A, int N, int M){
std::cout<<name<<std::endl;
for(int i = 0; i < N; i++){
for(int j = 0; j < M; j++){
std::cout << A[i*M+j] << " ";
}
std::cout << std::endl;
}
std::cout<<std::endl;
}
void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
// triple loop matmul and add bias
for (int i = 0; i < n; i++){
for (int j = 0; j < t; j++){
float sum = 0;
for (int k = 0; k < m; k++){
sum += A[i*m+k] * B[k*t+j];
}
C[i*t+j] += sum + bias[j];
}
}
}
void compute_reduction(float *in, float *out, int n, int m, int gs){
int ng;
if(gs == -1){
ng = 1;
gs = m;
}else{
ng = m/gs;
}
for(int i = 0; i < n; i++){
for(int j0 = 0; j0 < m; j0+=gs){
int j = j0/gs;
out[i*ng+j] = 0;
for(int j1 = j0; j1 < j0+gs; j1++){
out[i*ng+j] += in[i*m+j1];
}
}
}
}
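// Per-group sums of each input row; qforward's epilogue multiplies these by
// zeros * scales to restore the zero-point contribution to the matmul.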
void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
for (int i1 = i0; i1 < i0+gs; i1++){
uint32_t acc = 0;
int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
BQ[i1*m+j] = val;
}
}
}
}
void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
if(bits == 3){
for (int i1 = i0; i1 < i0+gs; i1+=32){
uint32_t acc = 0;
int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
int acc0 = 0, acc1 = 0, acc2 = 0;
acc0 |= temp0;
acc0 |= temp1;
acc0 |= temp2;
acc0 |= temp3;
acc0 |= temp4;
acc0 |= temp5;
acc0 |= temp6;
acc0 |= temp7;
acc0 |= temp8;
acc0 |= temp9;
acc0 |= temp10_0;
acc1 |= temp10_1;
acc1 |= temp11;
acc1 |= temp12;
acc1 |= temp13;
acc1 |= temp14;
acc1 |= temp15;
acc1 |= temp16;
acc1 |= temp17;
acc1 |= temp18;
acc1 |= temp19;
acc1 |= temp20;
acc1 |= temp21_0;
acc2 |= temp21_1;
acc2 |= temp22;
acc2 |= temp23;
acc2 |= temp24;
acc2 |= temp25;
acc2 |= temp26;
acc2 |= temp27;
acc2 |= temp28;
acc2 |= temp29;
acc2 |= temp30;
acc2 |= temp31;
BQ[(3*i1/32)*m+j] = acc0;
BQ[(3*i1/32+1)*m+j] = acc1;
BQ[(3*i1/32+2)*m+j] = acc2;
}
}else{
for (int i1 = i0; i1 < i0+gs; i1+=packed){
uint32_t acc = 0;
for (int i2 = i1; i2 < i1+packed; i2++){
int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
acc = acc | (temp << (bits*(i2-i1)));
}
BQ[(i1/packed)*m+j] = acc;
}
}
}
}
}
int main(int argc, char *argv[]){
// read n m t from args
if(argc < 6){std::cout << "Parameters not given\n"; return 0;}
int n = atoi(argv[1]);
int m = atoi(argv[2]);
int t = atoi(argv[3]);
int bits = atoi(argv[4]);
int gs = atoi(argv[5]);
int ng;
if(gs == -1){
ng = 1;
}else{
ng = m/gs;
}
float* A = new float[n*m];
float* AB = new float[n*m];
float* B = new float[m*t];
float* BQS = new float[m*t];
float* scales = new float[t*ng];
float* zeros = new float[t*ng];
int* BQ = new int[m*t/8];
int* BQB = new int[m*t/8];
float* sums = new float[n*ng];
float* bias = new float[t];
float* C = new float[n*t];
float* CB = new float[n*t];
float* C2 = new float[n*t];
srand(1);
for (int i = 0; i < n*m; i++){
A[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t*m; i++){
B[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t; i++){
bias[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < n*t; i++){
C[i] = 0.0;
C2[i] = 0.0;
}
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
oracle_mmadd(A, BQS, bias, C, n, m, t);
pack_input(A,AB);
pack_qw(BQ,BQB);
pack_output(C,CB);
compute_reduction(A,sums,n,m,gs);
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
float norm = 0.0;
for (int i = 0; i < n*t; i++){
norm += (C[i] - C2[i]) * (C[i] - C2[i]);
}
if(norm / (n*t) < 0.0001){
int iter = 30;
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
int num_runs = 15;
std::vector<long int> runs(num_runs);
for(int r = 0; r < num_runs; r++){
auto start = std::chrono::high_resolution_clock::now();
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
auto end = std::chrono::high_resolution_clock::now();
runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
}
std::sort(runs.begin(), runs.end());
float cycles_final = runs[num_runs/2 + 1] / iter;
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}else{
float cycles_final = 10e12f; // sentinel recorded when the correctness check fails
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}
return 0;
}

@@ -0,0 +1,85 @@
def includes():
out = " \
#include <torch/all.h>\n \
#include <torch/python.h>\n \
#include <omp.h>\n \
#include <cmath>\n \
#include <immintrin.h>\n \
\n \
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
"
return out
def module(bits_list=[4, 2]):
    out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
    for bits in bits_list:
        out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
    for bits in bits_list:
        out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
    out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
    out += 'm.def("unquantize_sim", &unquantize_sim);\n'
    # if oracle:
    #     out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
    out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
    out += '}\n'
    return out
def quant_scalar():
out = " \
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
//find scales and zeros arrays \n \
//quantize \n \
int pack = 32/bits;\n \
for (int j = 0; j < m; j++){\n \
for (int i = 0; i < n; i+=pack){\n \
uint32_t acc = 0;\n \
for (int ii = i; ii < i+pack; ii++){\n \
float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
int temp = (int)ftemp;\n \
acc = acc | (temp << (bits*(ii-i)));\n \
}\n \
BQ[(i/pack)*m+j] = acc;\n \
//BQ[0] = acc;\n \
}\n \
}\n \
}\n \
\n \
void quant_scalar_cpu(\n \
torch::Tensor in, torch::Tensor out, \n \
torch::Tensor scales, torch::Tensor zeros, int bits\n \
) {\n \
\n \
int N = in.size(0);\n \
int M = in.size(1);\n \
\n \
float* input = in.data_ptr<float>(); \n \
float* s = scales.data_ptr<float>();\n \
float* z = zeros.data_ptr<float>();\n \
int* O = out.data_ptr<int>();\n \
\n \
quantize_scalar(input, O, s, z, N, M, bits);\n \
\n \
}\n"
    return out
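These helpers return C++ source strings rather than writing files themselves. A minimal sketch of how they could be stitched together into one extension source file follows; this is an assumption for illustration, not the repository's actual generator, and the per-bit kernels referenced by module() (forward{bits}_cpu, pack{bits}_w_cpu, ...) are produced elsewhere and are not shown in this excerpt.

# A sketch only: concatenate the template strings above into a single C++ file
# that could be built as a torch extension once the kernel bodies are added.
def write_extension_source(path="generated.cpp", bits_list=[4, 2]):
    src = includes()
    src += quant_scalar()
    # ... kernel definitions for each bit width would be appended here ...
    src += module(bits_list=bits_list)
    with open(path, "w") as f:
        f.write(src)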


@@ -0,0 +1,37 @@
bits,nu,mu,tu,unroll,p,gs,time
4,1,16,16,1,8,-1,1.3814e+06
4,1,16,16,2,8,-1,1.44087e+06
4,1,16,16,4,8,-1,1.56173e+06
4,1,16,16,8,8,-1,1.41389e+06
3,1,16,16,5,8,-1,2.14748e+09
2,1,16,16,1,8,-1,1.09513e+06
2,1,16,16,2,8,-1,1.11322e+06
2,1,16,16,4,8,-1,1.12031e+06
2,1,16,16,8,8,-1,1.19086e+06
4,1,16,16,1,8,64,1.69111e+06
4,1,16,16,2,8,64,1.60056e+06
4,1,16,16,4,8,64,1.41263e+06
4,1,16,16,8,8,64,1.74572e+06
3,1,16,16,5,8,64,1.48062e+06
2,1,16,16,1,8,64,1.51234e+06
2,1,16,16,2,8,64,1.68108e+06
2,1,16,16,4,8,64,1.7624e+06
2,1,16,16,8,8,64,1.69563e+06
4,1,16,32,1,8,-1,1.24798e+06
4,1,16,32,2,8,-1,1.58421e+06
4,1,16,32,4,8,-1,2.10718e+06
4,1,16,32,8,8,-1,1.54288e+06
3,1,16,32,5,8,-1,2.14748e+09
2,1,16,32,1,8,-1,1.55906e+06
2,1,16,32,2,8,-1,1.58576e+06
2,1,16,32,4,8,-1,1.57993e+06
2,1,16,32,8,8,-1,1.80443e+06
4,1,16,32,1,8,64,1.58354e+06
4,1,16,32,2,8,64,1.63248e+06
4,1,16,32,4,8,64,1.91902e+06
4,1,16,32,8,8,64,1.9243e+06
3,1,16,32,5,8,64,1.33812e+06
2,1,16,32,1,8,64,1.77522e+06
2,1,16,32,2,8,64,1.54702e+06
2,1,16,32,4,8,64,1.78772e+06
2,1,16,32,8,8,64,1.49612e+06
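Each row records one benchmark configuration, with time in nanoseconds per qforward call (lower is better); the 2.14748e+09 entries appear to be the failure sentinel written when the accuracy check does not pass. A small sketch (not the repository's selection code, and the file path is the one used by the benchmark above) of picking the fastest configuration per (bits, gs):

import csv

# A sketch only: select the fastest sweep result for each (bits, gs) pair.
best = {}
with open("./autogptq_extension/qigen/tmp.csv") as f:
    for row in csv.DictReader(f):
        key = (int(row["bits"]), int(row["gs"]))
        t = float(row["time"])
        if key not in best or t < best[key][0]:
            best[key] = (t, row)

for (bits, gs), (t, row) in sorted(best.items()):
    print(f"bits={bits} gs={gs}: mu={row['mu']} tu={row['tu']} unroll={row['unroll']} time={t:.3g}")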

docs/NEWS_OR_UPDATE.md Normal file

@@ -0,0 +1,19 @@
## <center>News or Update</center>
- 2023-08-23 - (News) - 🤗 Transformers, optimum and peft have integrated `auto-gptq`, so running and training GPTQ models is now more accessible to everyone! See [this blog](https://huggingface.co/blog/gptq-integration) and its resources for more details!
- 2023-08-21 - (News) - The Qwen team officially released a 4-bit quantized version of Qwen-7B based on `auto-gptq`, and provided [detailed benchmark results](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4#%E9%87%8F%E5%8C%96-quantization)
- 2023-08-06 - (Update) - Support exllama's q4 CUDA kernel, giving at least a 1.3x inference speedup for int4 quantized models.
- 2023-08-04 - (Update) - Support ROCm so that AMD GPU users can use auto-gptq with CUDA extensions.
- 2023-07-26 - (Update) - An elegant [PPL benchmark script](examples/benchmark/perplexity.py) to get results that can be fairly compared with other libraries such as `llama.cpp`.
- 2023-06-05 - (Update) - Integrate with 🤗 peft to train adapters on GPTQ quantized models, supporting LoRA, AdaLoRA, AdaptionPrompt, etc.
- 2023-05-30 - (Update) - Support downloading/uploading quantized models from/to the 🤗 Hub.
- 2023-05-27 - (Update) - Support quantization and inference for `gpt_bigcode`, `codegen` and `RefinedWeb/RefinedWebModel` (falcon) model types.
- 2023-05-04 - (Update) - Support using a faster CUDA kernel when `not desc_act or group_size == -1`
- 2023-04-29 - (Update) - Support loading quantized model from arbitrary quantize_config and model_basename.
- 2023-04-28 - (Update) - Support CPU offload and quantize/inference on multiple devices, support `gpt2` type models.
- 2023-04-26 - (Update) - Using `triton` to speed up inference is now supported.
- 2023-04-25 - (News&Update) - [MOSS](https://github.com/OpenLMLab/MOSS) is an open-source tool-augmented conversational language model from Fudan University; its quantization is now supported in AutoGPTQ.
- 2023-04-23 - (Update) - Support evaluation on multiple downstream tasks such as language modeling, text classification and text summarization.
- 2023-04-22 - (News) - qwopqwop200's [AutoGPTQ-triton](https://github.com/qwopqwop200/AutoGPTQ-triton) provides faster inference for quantized models; everyone with access to triton, give it a try and enjoy!
- 2023-04-20 - (News) - AutoGPTQ is automatically compatible with Stability-AI's newly released `gpt_neox` type model family [StableLM](https://github.com/Stability-AI/StableLM).
- 2023-04-16 - (Update) - Support quantization and inference for `bloom`, `gpt_neox`, `gptj`, `llama` and `opt`.

Some files were not shown because too many files have changed in this diff.