From dba17c40fc67fd4e64a26214c47d745bf5a42d18 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 17 Feb 2025 17:31:11 -0800
Subject: [PATCH] Make transformers 4.49 functional

---
 modules/sampler_hijack.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index d202af1f..e0df49c3 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -5,7 +5,7 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -19,7 +19,7 @@ from modules.models import get_device
 global_scores = None
 
 
-class TemperatureLogitsWarperCustom(LogitsWarper):
+class TemperatureLogitsWarperCustom(LogitsProcessor):
     '''
     A copy of the original Transformers temperature logits warper.
     '''
@@ -42,7 +42,7 @@ class TemperatureLogitsWarperCustom(LogitsWarper):
         return scores
 
 
-class DynamicTemperatureLogitsWarper(LogitsWarper):
+class DynamicTemperatureLogitsWarper(LogitsProcessor):
     '''
     Dynamic temperature.
     '''
@@ -100,7 +100,7 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
         return scores
 
 
-class QuadraticSamplingLogitsWarper(LogitsWarper):
+class QuadraticSamplingLogitsWarper(LogitsProcessor):
     '''
     Quadratic sampling with smoothing factor and smoothing curve parameters.
     '''
@@ -127,7 +127,7 @@ class QuadraticSamplingLogitsWarper(LogitsWarper):
         return transformed_logits
 
 
-class TailFreeLogitsWarper(LogitsWarper):
+class TailFreeLogitsWarper(LogitsProcessor):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
         if tfs < 0 or tfs > 1.0:
@@ -167,7 +167,7 @@ class TailFreeLogitsWarper(LogitsWarper):
         return scores
 
 
-class TopALogitsWarper(LogitsWarper):
+class TopALogitsWarper(LogitsProcessor):
     def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         top_a = float(top_a)
         if top_a < 0 or top_a > 1.0:
@@ -194,7 +194,7 @@
 
 
 # Exclude Top Choices (XTC)
-class XTCLogitsWarper(LogitsWarper):
+class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
         self.threshold = threshold
         self.probability = probability
@@ -312,7 +312,7 @@ class DRYLogitsProcessor(LogitsProcessor):
         return scores
 
 
-class MirostatLogitsWarper(LogitsWarper):
+class MirostatLogitsWarper(LogitsProcessor):
     def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if mirostat_mode not in [2]:
             raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
@@ -361,7 +361,7 @@ class MirostatLogitsWarper(LogitsWarper):
         return scores
 
 
-class SpyLogitsWarper(LogitsWarper):
+class SpyLogitsWarper(LogitsProcessor):
     def __init__(self):
         pass
 
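
Note on the migration pattern: transformers 4.49 no longer provides the
`LogitsWarper` base class, which is why every custom warper in
sampler_hijack.py now subclasses `LogitsProcessor` instead. Both classes
share the same `__call__(input_ids, scores)` contract, so swapping the base
class is the whole migration; no method bodies change. Below is a minimal
sketch of the same pattern in isolation, assuming transformers >= 4.49 is
installed; `ExampleTemperatureWarper` is a hypothetical illustration, not a
class from this patch:

    import torch
    from transformers import LogitsProcessor  # `LogitsWarper` no longer exists in 4.49


    class ExampleTemperatureWarper(LogitsProcessor):
        '''
        Hypothetical temperature warper written against transformers >= 4.49.
        Before 4.49 it would have subclassed LogitsWarper; only the base
        class changes, the __call__ body stays the same.
        '''

        def __init__(self, temperature: float):
            temperature = float(temperature)
            if temperature <= 0:
                raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

            self.temperature = temperature

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # Same logits-in, logits-out contract as the old LogitsWarper.
            return scores / self.temperature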