mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-08 22:56:24 -04:00
Make transformers 4.49 functional
This commit is contained in:
parent
16f4f1a1c3
commit
dba17c40fc
1 changed files with 9 additions and 9 deletions
|
@ -5,7 +5,7 @@ import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import LogitsWarper
|
from transformers import LogitsProcessor
|
||||||
from transformers.generation.logits_process import (
|
from transformers.generation.logits_process import (
|
||||||
LogitNormalization,
|
LogitNormalization,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
|
@ -19,7 +19,7 @@ from modules.models import get_device
|
||||||
global_scores = None
|
global_scores = None
|
||||||
|
|
||||||
|
|
||||||
class TemperatureLogitsWarperCustom(LogitsWarper):
|
class TemperatureLogitsWarperCustom(LogitsProcessor):
|
||||||
'''
|
'''
|
||||||
A copy of the original Transformers temperature logits warper.
|
A copy of the original Transformers temperature logits warper.
|
||||||
'''
|
'''
|
||||||
|
@ -42,7 +42,7 @@ class TemperatureLogitsWarperCustom(LogitsWarper):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class DynamicTemperatureLogitsWarper(LogitsWarper):
|
class DynamicTemperatureLogitsWarper(LogitsProcessor):
|
||||||
'''
|
'''
|
||||||
Dynamic temperature.
|
Dynamic temperature.
|
||||||
'''
|
'''
|
||||||
|
@ -100,7 +100,7 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class QuadraticSamplingLogitsWarper(LogitsWarper):
|
class QuadraticSamplingLogitsWarper(LogitsProcessor):
|
||||||
'''
|
'''
|
||||||
Quadratic sampling with smoothing factor and smoothing curve parameters.
|
Quadratic sampling with smoothing factor and smoothing curve parameters.
|
||||||
'''
|
'''
|
||||||
|
@ -127,7 +127,7 @@ class QuadraticSamplingLogitsWarper(LogitsWarper):
|
||||||
return transformed_logits
|
return transformed_logits
|
||||||
|
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
class TailFreeLogitsWarper(LogitsProcessor):
|
||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
tfs = float(tfs)
|
tfs = float(tfs)
|
||||||
if tfs < 0 or tfs > 1.0:
|
if tfs < 0 or tfs > 1.0:
|
||||||
|
@ -167,7 +167,7 @@ class TailFreeLogitsWarper(LogitsWarper):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class TopALogitsWarper(LogitsWarper):
|
class TopALogitsWarper(LogitsProcessor):
|
||||||
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
top_a = float(top_a)
|
top_a = float(top_a)
|
||||||
if top_a < 0 or top_a > 1.0:
|
if top_a < 0 or top_a > 1.0:
|
||||||
|
@ -194,7 +194,7 @@ class TopALogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
|
|
||||||
# Exclude Top Choices (XTC)
|
# Exclude Top Choices (XTC)
|
||||||
class XTCLogitsWarper(LogitsWarper):
|
class XTCLogitsWarper(LogitsProcessor):
|
||||||
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
|
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.probability = probability
|
self.probability = probability
|
||||||
|
@ -312,7 +312,7 @@ class DRYLogitsProcessor(LogitsProcessor):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class MirostatLogitsWarper(LogitsWarper):
|
class MirostatLogitsWarper(LogitsProcessor):
|
||||||
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
if mirostat_mode not in [2]:
|
if mirostat_mode not in [2]:
|
||||||
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
|
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
|
||||||
|
@ -361,7 +361,7 @@ class MirostatLogitsWarper(LogitsWarper):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class SpyLogitsWarper(LogitsWarper):
|
class SpyLogitsWarper(LogitsProcessor):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue