fix bug: disable_exllama not forwarded in fused attention modules

qwopqwop200 2023-08-07 16:28:30 +09:00 committed by GitHub
parent 25972d65bf
commit 2f48780165
2 changed files with 4 additions and 2 deletions


@@ -235,10 +235,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
         desc_act=False,
         trainable=False,
         bits: int = 4,
+        disable_exllama=False,
         **kwargs
     ):
         config = model.config
-        QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits)
+        QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
         for name, m in model.named_modules():
             if not isinstance(m, GPTJAttention):
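
The bug being fixed: the fused attention injectors accepted disable_exllama through **kwargs but dropped it when calling dynamically_import_QuantLinear, so the exllama backend could still be picked even when the caller had disabled it. The sketch below is only an illustration of the kind of backend selection this flag controls; the real dynamically_import_QuantLinear lives in auto_gptq.utils.import_utils, and the exact branching shown here is an assumption, not the library's actual implementation.

# Illustrative sketch only: simplified assumption of how disable_exllama
# influences which QuantLinear backend dynamically_import_QuantLinear returns.
def dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=4, disable_exllama=False):
    if use_triton:
        from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
    elif bits == 4 and not disable_exllama:
        # the exllama kernel is only chosen when the caller has not disabled it
        from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
    elif not desc_act or group_size == -1:
        from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
    else:
        from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear
    return QuantLinear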


@@ -135,12 +135,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
         desc_act=False,
         trainable=False,
         bits: int = 4,
+        disable_exllama=False,
         **kwargs
     ):
         """
         Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
         """
-        QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits)
+        QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
         for name, m in model.named_modules():
             if not isinstance(m, LlamaAttention):
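
With the flag now forwarded in both injectors, disabling exllama from the user-facing loader also applies to the fused attention path. A hedged usage sketch, assuming the from_quantized keyword names used by auto-gptq around this release (inject_fused_attention, disable_exllama); the model directory is hypothetical.

# Usage sketch (assumed loader API surface at the time of this commit).
# "quantized-model-dir" is a hypothetical path to an already-quantized model.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "quantized-model-dir",
    device="cuda:0",
    inject_fused_attention=True,  # triggers the fused attention injectors patched above
    disable_exllama=True,         # now reaches dynamically_import_QuantLinear inside them
)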