Make transformers 4.49 functional

This commit is contained in:
oobabooga 2025-02-17 17:31:11 -08:00
parent 16f4f1a1c3
commit dba17c40fc

View file

@ -5,7 +5,7 @@ import random
import torch
import transformers
from transformers import LogitsWarper
from transformers import LogitsProcessor
from transformers.generation.logits_process import (
LogitNormalization,
LogitsProcessor,
@ -19,7 +19,7 @@ from modules.models import get_device
global_scores = None
class TemperatureLogitsWarperCustom(LogitsWarper):
class TemperatureLogitsWarperCustom(LogitsProcessor):
'''
A copy of the original Transformers temperature logits warper.
'''
@ -42,7 +42,7 @@ class TemperatureLogitsWarperCustom(LogitsWarper):
return scores
class DynamicTemperatureLogitsWarper(LogitsWarper):
class DynamicTemperatureLogitsWarper(LogitsProcessor):
'''
Dynamic temperature.
'''
@ -100,7 +100,7 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
return scores
class QuadraticSamplingLogitsWarper(LogitsWarper):
class QuadraticSamplingLogitsWarper(LogitsProcessor):
'''
Quadratic sampling with smoothing factor and smoothing curve parameters.
'''
@ -127,7 +127,7 @@ class QuadraticSamplingLogitsWarper(LogitsWarper):
return transformed_logits
class TailFreeLogitsWarper(LogitsWarper):
class TailFreeLogitsWarper(LogitsProcessor):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
@ -167,7 +167,7 @@ class TailFreeLogitsWarper(LogitsWarper):
return scores
class TopALogitsWarper(LogitsWarper):
class TopALogitsWarper(LogitsProcessor):
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
top_a = float(top_a)
if top_a < 0 or top_a > 1.0:
@ -194,7 +194,7 @@ class TopALogitsWarper(LogitsWarper):
# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsWarper):
class XTCLogitsWarper(LogitsProcessor):
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
self.threshold = threshold
self.probability = probability
@ -312,7 +312,7 @@ class DRYLogitsProcessor(LogitsProcessor):
return scores
class MirostatLogitsWarper(LogitsWarper):
class MirostatLogitsWarper(LogitsProcessor):
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if mirostat_mode not in [2]:
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
@ -361,7 +361,7 @@ class MirostatLogitsWarper(LogitsWarper):
return scores
class SpyLogitsWarper(LogitsWarper):
class SpyLogitsWarper(LogitsProcessor):
def __init__(self):
pass