text-generation-webui-mirror/modules/sampler_hijack.py
2025-03-14 16:45:11 -03:00

709 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import math
import pprint
import random
import torch
import transformers
from transformers.generation.logits_process import (
LogitNormalization,
LogitsProcessor,
LogitsProcessorList
)
from modules import shared
from modules.logging_colors import logger
from modules.models import get_device
# Latest scores tensor captured by SpyLogitsWarper.__call__; None until it first runs.
global_scores: "torch.FloatTensor | None" = None
class TemperatureLogitsWarperCustom(LogitsProcessor):
    '''
    A copy of the original Transformers temperature logits warper.

    Divides every logit by a fixed temperature, which must be a strictly
    positive ``float`` (ints are rejected; callers convert them beforehand).
    '''

    def __init__(self, temperature: float):
        is_positive_float = isinstance(temperature, float) and temperature > 0
        if not is_positive_float:
            message = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            if isinstance(temperature, float) and temperature == 0.0:
                # A zero temperature usually means the caller actually wanted greedy decoding.
                message += " If you're looking for greedy decoding strategies, set `do_sample=False`."
            raise ValueError(message)

        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.temperature
class DynamicTemperatureLogitsWarper(LogitsProcessor):
    '''
    Dynamic temperature.

    Each row of ``scores`` is divided by its own temperature, derived from the
    entropy ``H`` of that row's softmax distribution:

        temp = dynatemp_low + (dynatemp_high - dynatemp_low) * (H / log(n_valid)) ** dynatemp_exponent

    where ``n_valid`` is the number of logits in the row that are not ``-inf``.
    A normalized entropy of 1 (uniform over the valid tokens) maps to
    ``dynatemp_high``; entropy approaching 0 maps toward ``dynatemp_low``.
    '''

    def __init__(self, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
        self.dynatemp_low = dynatemp_low
        self.dynatemp_high = dynatemp_high
        self.dynatemp_exponent = dynatemp_exponent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        min_temp = self.dynatemp_low
        max_temp = self.dynatemp_high
        exponent_val = self.dynatemp_exponent

        # Do the statistics in float32: token counts above 65504 would overflow float16.
        logits = scores.float()
        probs = torch.softmax(logits, dim=-1)

        # Entropy per row (dim=-1), so every sequence in a batch gets its own temperature.
        # 0 * log(0) would be NaN, so zero-probability tokens contribute 0 explicitly.
        p_log_p = torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs))
        entropy = (-p_log_p.sum(dim=-1, keepdim=True)).clamp(min=1e-10)

        # Only logits that are not -inf count towards the maximum attainable entropy.
        num_valid_tokens = (logits > -float('inf')).sum(dim=-1, keepdim=True).float()
        # Guard against division by zero (a single valid token gives log(1) == 0).
        max_entropy = torch.log(num_valid_tokens).clamp(min=1e-10)

        normalized_entropy = entropy / max_entropy
        dyn_temp = min_temp + (max_temp - min_temp) * normalized_entropy.pow(exponent_val)

        # Cast back so the output keeps the dtype of the incoming scores.
        return scores / dyn_temp.to(scores.dtype)
class QuadraticSamplingLogitsWarper(LogitsProcessor):
    '''
    Quadratic sampling with smoothing factor and smoothing curve parameters.

    For every logit ``x`` that is not ``-inf``, with ``m`` the maximum logit of
    the same row and ``d = x - m``:

        x' = m - k * smoothing_factor * d**2 + s * smoothing_factor * d**3

    where ``k = (3 - smoothing_curve) / 2`` and ``s = (smoothing_curve - 1) / 2``.
    ``-inf`` logits are passed through unchanged.
    '''

    def __init__(self, smoothing_factor, smoothing_curve):
        self.smoothing_factor = smoothing_factor
        self.smoothing_curve = smoothing_curve

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Per-row maximum: the transform is not shift-invariant, so each sequence in a
        # batch must be smoothed around its own top logit, not the batch-wide one.
        max_logit = scores.max(dim=-1, keepdim=True).values
        diff = scores - max_logit
        k = (3 - self.smoothing_curve) / 2
        s = (self.smoothing_curve - 1) / 2

        transformed_logits = (
            -(k * self.smoothing_factor * diff ** 2)
            + (s * self.smoothing_factor * diff ** 3)
            + max_logit
        )
        # Only transform logits that are not -inf; masked tokens stay masked.
        return torch.where(scores != float('-inf'), transformed_logits, scores)
class TailFreeLogitsWarper(LogitsProcessor):
    '''
    Tail-free sampling (TFS).

    Sorts the distribution, takes the absolute second differences of the sorted
    probabilities, normalizes them into a CDF and removes the tokens whose CDF
    value exceeds ``tfs``. The mask is padded so that the most likely token is
    always kept and the least likely one is always removed, matching the
    original implementation of the algorithm.

    Args:
        tfs: Cutoff in [0, 1].
        filter_value: Logit assigned to removed tokens.
        min_tokens_to_keep: Number of most likely tokens that are never removed
            (only enforced when > 1).
    '''

    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        tfs = float(tfs)
        out_of_range = tfs < 0 or tfs > 1.0
        if out_of_range:
            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")

        self.tfs = tfs
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        curvature = probs.diff().diff().abs()
        curvature_cdf = (curvature / curvature.sum(dim=-1, keepdim=True)).cumsum(dim=-1)

        # The second difference is two entries shorter than the vocabulary; pad it by
        # keeping the head token and dropping the tail token to centre the cutoff.
        batch_size = scores.shape[0]
        keep_head = torch.zeros(batch_size, 1, dtype=torch.bool, device=scores.device)
        drop_tail = torch.ones(batch_size, 1, dtype=torch.bool, device=scores.device)
        sorted_mask = torch.cat((keep_head, curvature_cdf > self.tfs, drop_tail), dim=-1)

        if self.min_tokens_to_keep > 1:
            sorted_mask[..., :self.min_tokens_to_keep] = False

        # Map the mask from sorted order back to vocabulary order.
        vocab_mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        return scores.masked_fill(vocab_mask, self.filter_value)
class TopALogitsWarper(LogitsProcessor):
    '''
    Top-A sampling.

    Removes every token whose probability is below ``top_a * p_max ** 2``,
    where ``p_max`` is the probability of the most likely token.

    Args:
        top_a: Threshold factor in [0, 1].
        filter_value: Logit assigned to removed tokens.
        min_tokens_to_keep: Number of most likely tokens that are never removed
            (only enforced when > 1).
    '''

    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        out_of_range = top_a < 0 or top_a > 1.0
        if out_of_range:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")

        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        top_prob = probs[..., 0, None]
        sorted_mask = probs < top_prob * top_prob * self.top_a

        if self.min_tokens_to_keep > 1:
            sorted_mask[..., :self.min_tokens_to_keep] = False

        # Map the mask from sorted order back to vocabulary order.
        vocab_mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        return scores.masked_fill(vocab_mask, self.filter_value)
class TopNSigmaLogitsWarper(LogitsProcessor):
    '''
    Top-nσ sampling.

    Keeps only tokens whose logit is at least ``max_logit - n_sigma * std``,
    where ``max_logit`` and ``std`` are computed per row. ``std`` is taken over
    the finite logits only, so tokens already masked to ``-inf`` by earlier
    samplers do not distort the threshold.
    '''

    def __init__(self, n_sigma: float = 2.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        """
        Initialize Top-nσ Sampling logits warper.
        Args:
            n_sigma: The threshold multiplier for standard deviation
            filter_value: Value to assign to filtered logits
            min_tokens_to_keep: Minimum number of tokens to keep
        Raises:
            ValueError: If ``n_sigma`` is negative.
        """
        if n_sigma < 0:
            raise ValueError(f"`n_sigma` must be a non-negative float, but is {n_sigma}")
        self.n_sigma = n_sigma
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        max_logit = torch.max(scores, dim=-1, keepdim=True)[0]

        # Standard deviation over finite logits only. Replacing -inf with 0 and feeding
        # everything to torch.std would count masked tokens as zeros and inflate the std.
        finite_mask = torch.isfinite(scores)
        num_finite = finite_mask.sum(dim=-1, keepdim=True)
        finite_scores = scores.masked_fill(~finite_mask, 0.0)
        mean = finite_scores.sum(dim=-1, keepdim=True) / num_finite.clamp(min=1)
        squared_dev = (finite_scores - mean).pow(2).masked_fill(~finite_mask, 0.0)
        # Bessel-corrected (n - 1) like torch.std's default, so rows without -inf are unaffected.
        std_logit = (squared_dev.sum(dim=-1, keepdim=True) / (num_finite - 1).clamp(min=1)).sqrt()

        threshold = max_logit - self.n_sigma * std_logit
        indices_to_remove = scores < threshold

        if self.min_tokens_to_keep > 1:
            # topk fails if k exceeds the vocabulary size.
            keep = min(self.min_tokens_to_keep, scores.size(-1))
            top_k_indices = torch.topk(scores, keep, dim=-1)[1]
            indices_to_remove.scatter_(-1, top_k_indices, False)

        return scores.masked_fill(indices_to_remove, self.filter_value)
# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsProcessor):
    '''
    Exclude Top Choices (XTC).

    With probability ``probability``, removes every token whose probability is
    at least ``threshold`` except the least likely of them. If that would remove
    the newline token or the EOS token (IDs taken from ``shared.tokenizer`` at
    construction), the scores are returned unchanged.
    '''

    def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value

        newline_id = shared.tokenizer.encode("\n")[-1]
        eos_id = shared.tokenizer.eos_token_id
        self.special_token_ids = [newline_id] if eos_id is None else [newline_id, eos_id]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # random.random() lies in [0, 1): probability 0 never fires, probability 1 always does.
        # Rolling the dice before the threshold test is logically equivalent to the usual
        # description and lets us skip the sort entirely when XTC does not fire.
        if random.random() >= self.probability:
            return scores

        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Sorted position i is dropped when position i + 1 still clears the threshold,
        # i.e. all qualifying tokens except the least probable one.
        sorted_mask = torch.zeros_like(probs, dtype=torch.bool)
        sorted_mask[..., :-1] = probs[..., 1:] >= self.threshold

        vocab_mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)

        # Never exclude newline or EOS; fall back to the untouched scores instead.
        if vocab_mask[:, self.special_token_ids].any():
            return scores

        return scores.masked_fill(vocab_mask, self.filter_value)
class DRYLogitsProcessor(LogitsProcessor):
    '''
    DRY ("Don't Repeat Yourself") repetition penalty.

    For every earlier occurrence of the last input token, the match is extended
    backwards while the preceding tokens keep agreeing with the end of the input.
    Extension stops at the start of the input, at a sequence breaker, or at 50
    tokens. The token that followed each earlier occurrence has its score reduced
    by ``multiplier * base ** (match_length - allowed_length)`` when its longest
    match is at least ``allowed_length`` tokens. Scores are modified in place.

    Args:
        multiplier: Penalty scale.
        base: Base of the exponential growth of the penalty.
        allowed_length: Shortest match length that is penalized.
        sequence_breakers: Token IDs that interrupt matching.
        _range: If > 0, only the last ``_range`` input tokens are examined.
    '''

    def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers
        self._range = _range

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self._range > 0:
            input_ids = input_ids[:, -self._range:]

        breakers = self.sequence_breakers
        for ids_row, scores_row in zip(input_ids, scores):
            # Plain Python lists are much faster than tensors for this scalar-heavy scan.
            tokens = ids_row.tolist()
            tail_token = tokens[-1]
            if tail_token in breakers:
                continue

            # Longest backward match found for each candidate continuation token.
            longest_match = {}
            # The final position is skipped because it trivially matches itself.
            for start in [pos for pos, tok in enumerate(tokens[:-1]) if tok == tail_token]:
                follower = tokens[start + 1]
                if follower in breakers:
                    continue

                # `tail_token` already matches at `start`, so the match is at least 1 long.
                # The 50-token cap avoids overflow in the exponent below and bounds the work.
                length = 1
                while length < 50:
                    earlier = start - length
                    if earlier < 0:
                        break
                    preceding = tokens[-(length + 1)]
                    if tokens[earlier] != preceding or preceding in breakers:
                        break
                    length += 1

                if length > longest_match.get(follower, 0):
                    longest_match[follower] = length

            for token, length in longest_match.items():
                if length >= self.allowed_length:
                    scores_row[token] -= self.multiplier * self.base ** (length - self.allowed_length)

        return scores
class MirostatLogitsWarper(LogitsProcessor):
    '''
    Mirostat v2 sampling (only ``mirostat_mode == 2`` is accepted).

    Truncates the sorted distribution at the first token whose surprise
    (``-log2(p)``) exceeds ``mu``, always keeping at least one token. It then
    samples one token from the renormalized remainder and masks every other
    token with ``filter_value``. ``mu`` is updated by
    ``mirostat_eta * (observed_surprise - mirostat_tau)`` after each call.

    Only the first row of ``scores`` is used to pick the token; the resulting
    mask is broadcast over all rows. ``min_tokens_to_keep`` is stored but not
    read by this class.
    '''

    def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if mirostat_mode not in (2,):
            raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")

        self.mirostat_mode = mirostat_mode
        self.mirostat_eta = mirostat_eta
        self.mirostat_tau = mirostat_tau
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
        self.mu = 2 * self.mirostat_tau
        self.e = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        first_row = scores[0]
        sorted_logits, sorted_indices = torch.sort(first_row, descending=True)
        candidates = torch.softmax(sorted_logits, dim=-1).tolist()

        # Cut at the first candidate whose surprise exceeds mu, keeping at least one token.
        for rank, prob in enumerate(candidates):
            if prob > 0 and -math.log2(prob) > self.mu:
                sorted_logits = sorted_logits[:max(rank, 1)]
                break

        # Renormalize what is left and draw one token from it.
        prob_topk = torch.softmax(sorted_logits, dim=0)
        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

        device = get_device()
        if device:
            prob_topk = prob_topk.to(device)
            prev_i = prev_i.to(device)

        # Feedback step: move mu against the error between observed and target surprise.
        observed_surprise = -math.log2(prob_topk[prev_i])
        self.e = observed_surprise - self.mirostat_tau
        self.mu -= self.mirostat_eta * self.e

        # Remove everything except the sampled token, mapped back to vocabulary order.
        sorted_mask = torch.ones_like(scores[0], dtype=torch.bool)
        sorted_mask[prev_i] = False
        batched_mask = sorted_mask.unsqueeze(0)
        vocab_mask = batched_mask.scatter(1, sorted_indices.unsqueeze(0), batched_mask)
        return scores.masked_fill(vocab_mask, self.filter_value)
class SpyLogitsWarper(LogitsProcessor):
    '''
    Pass-through processor that records the scores it receives.

    The scores are stored, unmodified, in the module-level ``global_scores``
    and returned as-is.
    '''

    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        global global_scores
        global_scores = scores  # expose the final scores to the rest of the module
        return scores
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Multiplicative repetition penalty restricted to a window of recent tokens.

    Every distinct token among the last ``_range`` input IDs (all of them when
    ``_range`` is 0, since ``[-0:]`` selects everything) has a negative logit
    multiplied by ``penalty`` and a non-negative logit divided by it. Scores
    are modified in place.
    '''

    def __init__(self, penalty: float, _range: int):
        if not (penalty > 0):
            raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
        self.penalty = penalty
        self._range = _range

    def apply_repetition_penalty(self, input_ids_row, scores_row):
        seen_ids = torch.unique(input_ids_row)
        current = scores_row.gather(0, seen_ids)
        # Pushing a logit towards -inf always lowers it, whatever its sign.
        penalized = torch.where(current < 0, current * self.penalty, current / self.penalty)
        scores_row.scatter_(0, seen_ids, penalized)
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        recent_ids = input_ids[:, -self._range:]
        for ids_row, row_scores in zip(recent_ids, scores):
            self.apply_repetition_penalty(ids_row, row_scores)
        return scores
class PresencePenaltyLogitsProcessor(LogitsProcessor):
    '''
    Additive presence penalty restricted to a window of recent tokens.

    Subtracts ``presence_penalty`` once from the logit of every distinct token
    among the last ``_range`` input IDs (all of them when ``_range`` is 0),
    regardless of how often it occurs. Scores are modified in place.
    '''

    def __init__(self, presence_penalty: float, _range: int):
        self.presence_penalty = presence_penalty
        self._range = _range

    def apply_presence_penalty(self, input_ids_row, scores_row):
        token_ids, occurrences = torch.unique(input_ids_row, return_counts=True)
        # Every unique ID occurs at least once, so this is a vector of ones.
        present = (occurrences > 0).to(scores_row.dtype)
        scores_row.scatter_add_(0, token_ids, -(present * self.presence_penalty))
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        recent_ids = input_ids[:, -self._range:]
        for ids_row, row_scores in zip(recent_ids, scores):
            self.apply_presence_penalty(ids_row, row_scores)
        return scores
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    '''
    Additive frequency penalty restricted to a window of recent tokens.

    Subtracts ``frequency_penalty * count`` from the logit of every token among
    the last ``_range`` input IDs (all of them when ``_range`` is 0), where
    ``count`` is how often the token occurs in that window. Scores are
    modified in place.
    '''

    def __init__(self, frequency_penalty: float, _range: int):
        self.frequency_penalty = frequency_penalty
        self._range = _range

    def apply_frequency_penalty(self, input_ids_row, scores_row):
        token_ids, occurrences = torch.unique(input_ids_row, return_counts=True)
        per_token = occurrences.to(scores_row.dtype) * self.frequency_penalty
        scores_row.scatter_add_(0, token_ids, -per_token)
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        recent_ids = input_ids[:, -self._range:]
        for ids_row, row_scores in zip(recent_ids, scores):
            self.apply_frequency_penalty(ids_row, row_scores)
        return scores
# Maps processor class names to the nicknames used in `generation_config.sampler_priority`.
_SAMPLER_NICKNAMES = {
    'DynamicTemperatureLogitsWarper': 'dynamic_temperature',
    'EpsilonLogitsWarper': 'epsilon_cutoff',
    'EtaLogitsWarper': 'eta_cutoff',
    'MinPLogitsWarper': 'min_p',
    'MirostatLogitsWarper': 'mirostat',
    'QuadraticSamplingLogitsWarper': 'quadratic_sampling',
    'TailFreeLogitsWarper': 'tfs',
    'TemperatureLogitsWarperCustom': 'temperature',
    'TopALogitsWarper': 'top_a',
    'TopNSigmaLogitsWarper': 'top_n_sigma',
    'TopKLogitsWarper': 'top_k',
    'TopPLogitsWarper': 'top_p',
    'TypicalLogitsWarper': 'typical_p',
    'XTCLogitsWarper': 'xtc',
    'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
    'PresencePenaltyLogitsProcessor': 'presence_penalty',
    'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
    'DRYLogitsProcessor': 'dry',
    'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
    'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
}


def _parse_dry_sequence_breakers(raw_breakers):
    '''
    Convert the `dry_sequence_breakers` setting into a set of token IDs.

    Accepts a JSON array string, the comma-separated body of one (without the
    surrounding brackets), or an already-parsed iterable of strings.
    '''
    if isinstance(raw_breakers, str):
        text = raw_breakers.strip()
        # Support both JSON array notation and comma-separated strings.
        if not text.startswith("["):
            text = "[" + text + "]"
        breaker_strings = json.loads(text)
    else:
        breaker_strings = list(raw_breakers)

    # Prefix with 'a' to get the correct encoding of the token at the end of a text.
    return {shared.tokenizer.encode(f'a{s}')[-1] for s in breaker_strings}


def get_logits_processor_patch(self, **kwargs):
    '''
    Replacement for `GenerationMixin._get_logits_processor` (installed by `hijack_samplers`).

    Builds the stock Transformers processor list, then:
      * swaps `TemperatureLogitsWarper` for `TemperatureLogitsWarperCustom`,
      * drops the stock `RepetitionPenaltyLogitsProcessor` (superseded by the
        range-aware version added here),
      * appends this module's samplers as enabled by `generation_config`,
      * orders everything by `generation_config.sampler_priority` (processors
        without a listed nickname keep their relative order and run first),
      * re-appends a trailing `LogitNormalization` and finally a `SpyLogitsWarper`.

    Returns:
        A `LogitsProcessorList` with the processors in execution order.
    '''
    generation_config = kwargs['generation_config']

    # Parameter sanitization: TemperatureLogitsWarperCustom only accepts floats.
    if isinstance(generation_config.temperature, int):
        generation_config.temperature = float(generation_config.temperature)

    warpers = self._get_logits_processor_old(**kwargs)

    # Walk backwards so `del` does not shift the indices still to be visited.
    for i in range(len(warpers) - 1, -1, -1):
        class_name = warpers[i].__class__.__name__
        if class_name == 'TemperatureLogitsWarper':
            warpers[i] = TemperatureLogitsWarperCustom(generation_config.temperature)
        elif class_name == 'RepetitionPenaltyLogitsProcessor':
            del warpers[i]

    # Custom processors are collected separately so that a trailing LogitNormalization
    # in the stock list can still be detected below.
    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

    if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
        warpers_to_add.append(
            RepetitionPenaltyLogitsProcessorWithRange(
                penalty=generation_config.repetition_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
        warpers_to_add.append(
            PresencePenaltyLogitsProcessor(
                presence_penalty=generation_config.presence_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
        warpers_to_add.append(
            FrequencyPenaltyLogitsProcessor(
                frequency_penalty=generation_config.frequency_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
        # Must go into warpers_to_add (not `warpers`): appending it to `warpers` would hide a
        # trailing LogitNormalization, which would then be sorted to the front instead of last.
        warpers_to_add.append(
            DRYLogitsProcessor(
                multiplier=generation_config.dry_multiplier,
                base=generation_config.dry_base,
                allowed_length=generation_config.dry_allowed_length,
                sequence_breakers=_parse_dry_sequence_breakers(generation_config.dry_sequence_breakers),
                _range=generation_config.repetition_penalty_range,
            )
        )

    if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
        warpers_to_add.append(
            TailFreeLogitsWarper(
                tfs=generation_config.tfs,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
        warpers_to_add.append(
            TopALogitsWarper(
                top_a=generation_config.top_a,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.top_n_sigma is not None and generation_config.top_n_sigma > 0.0:
        warpers_to_add.append(
            TopNSigmaLogitsWarper(
                n_sigma=generation_config.top_n_sigma,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
        warpers_to_add.append(
            XTCLogitsWarper(
                threshold=generation_config.xtc_threshold,
                probability=generation_config.xtc_probability,
            )
        )

    if generation_config.dynamic_temperature:
        warpers_to_add.append(
            DynamicTemperatureLogitsWarper(
                dynatemp_low=generation_config.dynatemp_low,
                dynatemp_high=generation_config.dynatemp_high,
                dynatemp_exponent=generation_config.dynatemp_exponent,
            )
        )

    if generation_config.smoothing_factor is not None and generation_config.smoothing_factor > 0:
        warpers_to_add.append(
            QuadraticSamplingLogitsWarper(
                smoothing_factor=generation_config.smoothing_factor,
                smoothing_curve=generation_config.smoothing_curve
            )
        )

    if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
        warpers_to_add.append(
            MirostatLogitsWarper(
                mirostat_mode=generation_config.mirostat_mode,
                mirostat_eta=generation_config.mirostat_eta,
                mirostat_tau=generation_config.mirostat_tau,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    # Set aside a trailing LogitNormalization so it can be re-appended after sorting.
    if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
        normalize = warpers.pop(-1)
    else:
        normalize = None

    warpers += warpers_to_add

    sampler_priority = generation_config.sampler_priority
    if generation_config.temperature_last:
        # Reorder a copy so generation_config.sampler_priority is not mutated in place.
        sampler_priority = list(sampler_priority)
        for param_name in ('temperature', 'dynamic_temperature', 'quadratic_sampling'):
            if param_name in sampler_priority:
                sampler_priority.remove(param_name)
            sampler_priority.append(param_name)

    def custom_sort_key(obj):
        # Processors without a nickname, or whose nickname is not prioritized, get -1.
        nickname = _SAMPLER_NICKNAMES.get(obj.__class__.__name__)
        if nickname is None or nickname not in sampler_priority:
            return -1
        return sampler_priority.index(nickname)

    # `sorted` is stable, so processors with equal keys keep their relative order.
    warpers = sorted(warpers, key=custom_sort_key)
    if shared.args.verbose:
        logger.info("WARPERS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
        print()

    if normalize is not None:
        warpers.append(normalize)

    warpers.append(SpyLogitsWarper())
    return LogitsProcessorList(warpers)
def generation_config_init_patch(self, **kwargs):
    '''
    Replacement for `GenerationConfig.__init__` (installed by `hijack_samplers`).

    Runs the original initializer with all keyword arguments, then sets this
    module's extra sampler attributes, taking each value from `kwargs` or
    falling back to the default listed below.
    '''
    self.__init___old(**kwargs)

    # Built on every call so no two configs share the mutable sampler_priority list.
    defaults = {
        "min_p": 0.0,
        "dynamic_temperature": False,
        "dynatemp_low": 1,
        "dynatemp_high": 1,
        "dynatemp_exponent": 1,
        "smoothing_factor": 0.0,
        "smoothing_curve": 1.0,
        "tfs": 1.0,
        "top_a": 0.0,
        "top_n_sigma": 0.0,
        "mirostat_mode": 0,
        "mirostat_eta": 0.1,
        "mirostat_tau": 5,
        "repetition_penalty_range": 0,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "dry_multiplier": 0.0,
        "dry_base": 1.75,
        "dry_allowed_length": 2,
        "dry_sequence_breakers": '"\\n", ":", "\\"", "*"',
        "xtc_threshold": 0.1,
        "xtc_probability": 0,
        "temperature_last": False,
        "sampler_priority": [
            'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry',
            'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma',
            'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs',
            'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty',
            'no_repeat_ngram',
        ],
    }
    for attr_name, default in defaults.items():
        setattr(self, attr_name, kwargs.get(attr_name, default))
def hijack_samplers():
    '''
    Monkey-patch Transformers so generation uses this module's sampler stack.

    Saves the original `GenerationMixin._get_logits_processor` and
    `GenerationConfig.__init__` under `*_old` names and installs
    `get_logits_processor_patch` / `generation_config_init_patch` in their place.

    Calling this more than once is a no-op after the first call. Without the guard,
    a second call would store the patch itself as `*_old`, and the patch's call to
    the "original" would recurse forever.
    '''
    mixin = transformers.GenerationMixin
    if mixin._get_logits_processor is not get_logits_processor_patch:
        mixin._get_logits_processor_old = mixin._get_logits_processor
        mixin._get_logits_processor = get_logits_processor_patch

    config_cls = transformers.GenerationConfig
    if config_cls.__init__ is not generation_config_init_patch:
        config_cls.__init___old = config_cls.__init__
        config_cls.__init__ = generation_config_init_patch