Make transformers 4.49 functional

oobabooga 2025-02-17 17:31:11 -08:00
parent 16f4f1a1c3
commit dba17c40fc

@@ -5,7 +5,7 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
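
transformers 4.49 drops the LogitsWarper base class (the top-level import fails), so every custom sampler below is rebased onto LogitsProcessor; both classes define the same __call__(input_ids, scores) -> scores contract, which makes this a drop-in base-class swap, as the remaining hunks show. A minimal compatibility sketch, not part of this commit, for code that has to import a base class on both sides of the removal:

# Hypothetical shim (not in this commit): pick whichever base class the
# installed transformers version still ships. Both define
# __call__(input_ids, scores) -> scores, so subclasses work unchanged.
try:
    from transformers import LogitsWarper as LogitsBase  # transformers < 4.49
except ImportError:
    from transformers import LogitsProcessor as LogitsBase  # transformers >= 4.49
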
@@ -19,7 +19,7 @@ from modules.models import get_device
 global_scores = None
 
 
-class TemperatureLogitsWarperCustom(LogitsWarper):
+class TemperatureLogitsWarperCustom(LogitsProcessor):
     '''
     A copy of the original Transformers temperature logits warper.
     '''
@@ -42,7 +42,7 @@ class TemperatureLogitsWarperCustom(LogitsWarper):
         return scores
 
 
-class DynamicTemperatureLogitsWarper(LogitsWarper):
+class DynamicTemperatureLogitsWarper(LogitsProcessor):
     '''
     Dynamic temperature.
     '''
@@ -100,7 +100,7 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
         return scores
 
 
-class QuadraticSamplingLogitsWarper(LogitsWarper):
+class QuadraticSamplingLogitsWarper(LogitsProcessor):
     '''
     Quadratic sampling with smoothing factor and smoothing curve parameters.
     '''
@@ -127,7 +127,7 @@ class QuadraticSamplingLogitsWarper(LogitsWarper):
         return transformed_logits
 
 
-class TailFreeLogitsWarper(LogitsWarper):
+class TailFreeLogitsWarper(LogitsProcessor):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
         if tfs < 0 or tfs > 1.0:
@@ -167,7 +167,7 @@ class TailFreeLogitsWarper(LogitsWarper):
         return scores
 
 
-class TopALogitsWarper(LogitsWarper):
+class TopALogitsWarper(LogitsProcessor):
     def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         top_a = float(top_a)
         if top_a < 0 or top_a > 1.0:
@@ -194,7 +194,7 @@ class TopALogitsWarper(LogitsWarper):
 
 
 # Exclude Top Choices (XTC)
-class XTCLogitsWarper(LogitsWarper):
+class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
         self.threshold = threshold
         self.probability = probability
@@ -312,7 +312,7 @@ class DRYLogitsProcessor(LogitsProcessor):
         return scores
 
 
-class MirostatLogitsWarper(LogitsWarper):
+class MirostatLogitsWarper(LogitsProcessor):
     def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if mirostat_mode not in [2]:
             raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
@@ -361,7 +361,7 @@ class MirostatLogitsWarper(LogitsWarper):
         return scores
 
 
-class SpyLogitsWarper(LogitsWarper):
+class SpyLogitsWarper(LogitsProcessor):
     def __init__(self):
         pass
 
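
For reference, a minimal sketch of the shape these patched classes take under transformers >= 4.49. This is an illustrative processor, not code from the commit; it only demonstrates the __call__ interface that LogitsProcessor shares with the removed LogitsWarper:

import torch
from transformers import LogitsProcessor


class FixedTemperatureProcessor(LogitsProcessor):
    '''
    Illustrative only (not from this commit): divides logits by a fixed
    temperature, mirroring the structure of the samplers patched above.
    '''

    def __init__(self, temperature: float):
        if temperature <= 0:
            raise ValueError(f"`temperature` has to be > 0, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Same contract the removed LogitsWarper used: receive the next-token
        # scores for the current step and return them transformed.
        return scores / self.temperature

An instance can be appended to a transformers LogitsProcessorList and passed to model.generate(logits_processor=...) exactly as a warper used to be, which is why the one-line base-class swap in this commit is sufficient.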