Merge pull request #355 from PanQiWei/fix_pack_model_use_exllamav2

import exllama QuantLinear instead of exllamav2's in `pack_model`
This commit is contained in:
潘其威(William) 2023-09-27 11:06:35 +08:00 committed by GitHub
commit 51c043c6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -188,7 +188,7 @@ def pack_model(
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=False, disable_exllamav2=True)
if force_layer_back_to_cpu:
model.to(CPU)
@@ -196,7 +196,7 @@ def pack_model(
logger.info('Packing model...')
layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act)
make_quant(model, quantizers, bits, group_size, use_triton=use_triton, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act, disable_exllama=False, disable_exllamav2=True)
qlayers = find_layers(model, [QuantLinear])
for name in qlayers:
logger.info(name)