Mirror of https://github.com/oobabooga/text-generation-webui.git
Make transformers 4.49 functional
commit dba17c40fc
parent 16f4f1a1c3
1 changed file with 9 additions and 9 deletions
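The change is mechanical: transformers 4.49 drops the deprecated LogitsWarper base class, so every custom sampler that subclassed it fails at import time. Each class is rebased onto LogitsProcessor, which exposes the same __call__(input_ids, scores) -> scores contract, so no method bodies need to change. For code that must run on both sides of the 4.49 boundary, an import shim could look like the sketch below (not part of this commit; SamplerBase and NoOpSampler are hypothetical names):

    import torch

    # Minimal compatibility sketch, assuming only that LogitsWarper existed
    # before transformers 4.49 and LogitsProcessor exists in all relevant versions.
    try:
        from transformers import LogitsWarper as SamplerBase  # transformers < 4.49
    except ImportError:
        from transformers import LogitsProcessor as SamplerBase  # 4.49+ removed LogitsWarper


    class NoOpSampler(SamplerBase):
        '''Hypothetical sampler: returns scores unchanged, showing the shared contract.'''

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            return scores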
@@ -5,7 +5,7 @@ import random
 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -19,7 +19,7 @@ from modules.models import get_device
 global_scores = None


-class TemperatureLogitsWarperCustom(LogitsWarper):
+class TemperatureLogitsWarperCustom(LogitsProcessor):
     '''
     A copy of the original Transformers temperature logits warper.
     '''
@@ -42,7 +42,7 @@ class TemperatureLogitsWarperCustom(LogitsWarper):
         return scores


-class DynamicTemperatureLogitsWarper(LogitsWarper):
+class DynamicTemperatureLogitsWarper(LogitsProcessor):
     '''
     Dynamic temperature.
     '''
@@ -100,7 +100,7 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
         return scores


-class QuadraticSamplingLogitsWarper(LogitsWarper):
+class QuadraticSamplingLogitsWarper(LogitsProcessor):
     '''
     Quadratic sampling with smoothing factor and smoothing curve parameters.
     '''
@@ -127,7 +127,7 @@ class QuadraticSamplingLogitsWarper(LogitsWarper):
         return transformed_logits


-class TailFreeLogitsWarper(LogitsWarper):
+class TailFreeLogitsWarper(LogitsProcessor):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
         if tfs < 0 or tfs > 1.0:
@@ -167,7 +167,7 @@ class TailFreeLogitsWarper(LogitsWarper):
         return scores


-class TopALogitsWarper(LogitsWarper):
+class TopALogitsWarper(LogitsProcessor):
     def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         top_a = float(top_a)
         if top_a < 0 or top_a > 1.0:
@@ -194,7 +194,7 @@ class TopALogitsWarper(LogitsWarper):


 # Exclude Top Choices (XTC)
-class XTCLogitsWarper(LogitsWarper):
+class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
         self.threshold = threshold
         self.probability = probability
@@ -312,7 +312,7 @@ class DRYLogitsProcessor(LogitsProcessor):
         return scores


-class MirostatLogitsWarper(LogitsWarper):
+class MirostatLogitsWarper(LogitsProcessor):
     def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if mirostat_mode not in [2]:
             raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
@@ -361,7 +361,7 @@ class MirostatLogitsWarper(LogitsWarper):
         return scores


-class SpyLogitsWarper(LogitsWarper):
+class SpyLogitsWarper(LogitsProcessor):
     def __init__(self):
         pass
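Because generation only ever drives these objects through LogitsProcessorList, which calls each entry as fn(input_ids, scores), the rebased samplers compose exactly as before. A hypothetical usage sketch (HalveScores is an illustrative stand-in, not a class from this diff):

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList


    class HalveScores(LogitsProcessor):
        '''Illustrative stand-in for one of the rebased samplers above.'''

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            return scores / 2.0


    # LogitsProcessorList is a plain list subclass; calling it applies each
    # processor to the scores in order.
    processors = LogitsProcessorList([HalveScores()])
    new_scores = processors(torch.zeros((1, 4), dtype=torch.long), torch.randn(1, 32000))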