Fix CFG with ExLlamaV2_HF (closes #6937)

This commit is contained in:
oobabooga 2025-04-30 18:43:45 -07:00
parent ec2e641749
commit 55283bb8f1

View file

@@ -65,7 +65,7 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
elif kv_cache_type == 'q4':
cache_type = ExLlamaV2Cache_Q4
else:
raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
# Use TP if specified
if shared.args.enable_tp:
@@ -78,12 +78,10 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
self.past_seq = None
if shared.args.cfg_cache:
if shared.args.cache_8bit:
self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
elif shared.args.cache_4bit:
self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)
if shared.args.enable_tp:
self.ex_cache_negative = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
else:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
self.ex_cache_negative = cache_type(self.ex_model, lazy=shared.args.autosplit)
self.past_seq_negative = None