Fix the exllamav2_HF and exllamav3_HF loaders

oobabooga 2025-04-21 18:32:23 -07:00
parent 15989c2ed8
commit 8320190184

modules/sampler_hijack.py

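The substance of the fix: the patched functions previously reached back to the originals through "_old" attributes stashed on the transformers classes at hijack time; this commit captures the originals in module-level variables at import time instead. A plausible failure mode of the old pattern is sketched below with illustrative names (fragile_hijack, robust_hijack, and the min_p handling are adapted from the diff, not verbatim): if the hijack ever runs more than once, the "_old" attribute is overwritten with the already-patched function, and the patch then calls itself until RecursionError.

import transformers

# Fragile pattern (before this commit): the original __init__ is stashed as an
# attribute on the class at hijack time.
def fragile_init_patch(self, **kwargs):
    # Looks up whatever __init___old holds at call time. If fragile_hijack()
    # ever runs twice, __init___old is fragile_init_patch itself, and this
    # line recurses until RecursionError.
    self.__init___old(**kwargs)
    self.min_p = kwargs.pop("min_p", 0.0)

def fragile_hijack():
    transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
    transformers.GenerationConfig.__init__ = fragile_init_patch

# Robust pattern (this commit): capture the original once, at module import
# time, before any patching can have happened.
original_init = transformers.GenerationConfig.__init__

def robust_init_patch(self, **kwargs):
    original_init(self, **kwargs)  # always the true original
    self.min_p = kwargs.pop("min_p", 0.0)

def robust_hijack():
    # Idempotent: nothing is re-captured, so repeat calls are harmless.
    transformers.GenerationConfig.__init__ = robust_init_patch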
@@ -15,6 +15,9 @@ from modules import shared
 from modules.logging_colors import logger
 from modules.torch_utils import get_device
 
+original_init = transformers.GenerationConfig.__init__
+original_get_logits_processor = transformers.GenerationMixin._get_logits_processor
+
 global_scores = None
@@ -484,7 +487,7 @@ def get_logits_processor_patch(self, **kwargs):
         generation_config.temperature = float(generation_config.temperature)  # Must be float
 
     # Get the original warpers
-    warpers = self._get_logits_processor_old(**kwargs)
+    warpers = original_get_logits_processor(self, **kwargs)
 
     for i in range(len(warpers) - 1, -1, -1):
         # Replace temperature with our modified class.
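Note the call shape in the replacement line above: original_get_logits_processor holds the plain, unbound function object, so the patch passes self explicitly. A tiny standalone illustration of why a saved reference always reaches the pre-patch implementation (class names are made up):

class Base:
    def method(self):
        return "original"

saved = Base.method            # plain function in Python 3, bound to nothing

class Sub(Base):
    def method(self):
        return "patched"

print(saved(Sub()))            # prints "original": explicit self, no attribute lookup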
@@ -674,7 +677,7 @@ def get_logits_processor_patch(self, **kwargs):
 
 
 def generation_config_init_patch(self, **kwargs):
-    self.__init___old(**kwargs)
+    original_init(self, **kwargs)
     self.min_p = kwargs.pop("min_p", 0.0)
     self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
     self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
@@ -702,8 +705,5 @@ def generation_config_init_patch(self, **kwargs):
 
 
 def hijack_samplers():
-    transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
     transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
-
-    transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
     transformers.GenerationConfig.__init__ = generation_config_init_patch
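A quick sanity check of the new behavior (illustrative, not part of the commit; the import path is assumed from the module names in the diff):

import transformers
from modules.sampler_hijack import hijack_samplers  # path assumed

hijack_samplers()
hijack_samplers()  # safe now: the originals were captured at import time

config = transformers.GenerationConfig(min_p=0.05)
assert config.min_p == 0.05     # set by the patched __init__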