Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-06-07 06:06:20 -04:00
Commit 4fa52a1302
1 changed file with 47 additions and 38 deletions
@@ -55,31 +55,6 @@ class LlamaServer:
         return result.get("content", "")
 
     def prepare_payload(self, state):
-        # Prepare DRY
-        dry_sequence_breakers = state['dry_sequence_breakers']
-        if not dry_sequence_breakers.startswith("["):
-            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
-        dry_sequence_breakers = json.loads(dry_sequence_breakers)
-
-        # Prepare the sampler order
-        samplers = state["sampler_priority"]
-        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
-        penalty_found = False
-        filtered_samplers = []
-        for s in samplers:
-            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
-                filtered_samplers.append(s.strip())
-            elif not penalty_found and s.strip() == "repetition_penalty":
-                filtered_samplers.append("penalties")
-                penalty_found = True
-
-        samplers = filtered_samplers
-
-        # Move temperature to the end if temperature_last is true and temperature exists in the list
-        if state["temperature_last"] and "temperature" in samplers:
-            samplers.remove("temperature")
-            samplers.append("temperature")
-
         payload = {
             "temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
             "dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
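For reference, the dynamic-temperature arithmetic in the context lines above maps the UI's [dynatemp_low, dynatemp_high] interval to a midpoint plus half-width, sent as the "temperature" and "dynatemp_range" fields. A standalone sketch of that arithmetic (illustration only, not part of the commit):

# Illustration only: the midpoint / half-width mapping used for dynamic temperature.
def dynatemp_fields(dynatemp_low, dynatemp_high):
    temperature = (dynatemp_low + dynatemp_high) / 2      # midpoint of the range
    dynatemp_range = (dynatemp_high - dynatemp_low) / 2   # half-width of the range
    return temperature, dynatemp_range

print(dynatemp_fields(0.5, 1.5))  # (1.0, 0.5): samples between 0.5 and 1.5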
@@ -97,7 +72,6 @@ class LlamaServer:
             "dry_base": state["dry_base"],
             "dry_allowed_length": state["dry_allowed_length"],
             "dry_penalty_last_n": state["repetition_penalty_range"],
-            "dry_sequence_breakers": dry_sequence_breakers,
             "xtc_probability": state["xtc_probability"],
             "xtc_threshold": state["xtc_threshold"],
             "mirostat": state["mirostat_mode"],
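This hunk drops "dry_sequence_breakers" from the dict literal; the value is re-added in the next hunk after being normalized and parsed. The string coming from the UI may or may not already be wrapped in JSON brackets, so it is bracketed and then parsed with json.loads. A minimal standalone sketch of that normalization (the sample value is assumed for illustration, not taken from the commit):

import json

# Illustration only: bracket a bare comma-separated string, then parse it as JSON.
def parse_sequence_breakers(raw):
    if not raw.startswith("["):
        raw = "[" + raw + "]"
    return json.loads(raw)

print(parse_sequence_breakers('"\\n", ":", "\\"", "*"'))  # ['\n', ':', '"', '*']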
@@ -106,20 +80,44 @@ class LlamaServer:
             "grammar": state["grammar_string"],
             "seed": state["seed"],
             "ignore_eos": state["ban_eos_token"],
-            "samplers": samplers,
         }
 
+        # DRY
+        dry_sequence_breakers = state['dry_sequence_breakers']
+        if not dry_sequence_breakers.startswith("["):
+            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+
+        dry_sequence_breakers = json.loads(dry_sequence_breakers)
+        payload["dry_sequence_breakers"] = dry_sequence_breakers
+
+        # Sampler order
+        if state["sampler_priority"]:
+            samplers = state["sampler_priority"]
+            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+            filtered_samplers = []
+
+            penalty_found = False
+            for s in samplers:
+                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                    filtered_samplers.append(s.strip())
+                elif not penalty_found and s.strip() == "repetition_penalty":
+                    filtered_samplers.append("penalties")
+                    penalty_found = True
+
+            # Move temperature to the end if temperature_last is true and temperature exists in the list
+            if state["temperature_last"] and "temperature" in samplers:
+                samplers.remove("temperature")
+                samplers.append("temperature")
+
+            payload["samplers"] = filtered_samplers
+
         if state['custom_token_bans']:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban
 
         return payload
 
-    def generate_with_streaming(
-        self,
-        prompt,
-        state,
-    ):
+    def generate_with_streaming(self, prompt, state):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)
 
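A reading aid for the sampler-order block added above (a simplified standalone sketch, not the commit's code; note that this sketch applies the temperature move to the filtered list): names from the newline-separated priority string are kept only if the server knows them, a single "repetition_penalty" entry is mapped to the grouped "penalties" stage, and temperature can optionally be forced to run last.

# Illustration only: filter and reorder a newline-separated sampler priority string.
KNOWN_SAMPLERS = ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]

def order_samplers(priority, temperature_last=False):
    ordered = []
    for name in (s.strip() for s in priority.split("\n")):
        if name in KNOWN_SAMPLERS:
            ordered.append(name)
        elif name == "repetition_penalty" and "penalties" not in ordered:
            ordered.append("penalties")  # the penalty samplers run as one grouped stage
    if temperature_last and "temperature" in ordered:
        ordered.remove("temperature")
        ordered.append("temperature")
    return ordered

print(order_samplers("repetition_penalty\ntemperature\ntop_k\ntop_p", temperature_last=True))
# ['penalties', 'top_k', 'top_p', 'temperature']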
@@ -178,6 +176,13 @@ class LlamaServer:
                 print(f"Problematic line: {line}")
                 continue
 
+    def generate(self, prompt, state):
+        output = ""
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
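The new blocking generate() above simply drains the streaming generator and returns whatever it yielded last; assuming the stream yields the accumulated text so far, that last value is the full completion. A tiny standalone illustration of the pattern, using a stand-in generator:

# Illustration only: a loop variable keeps the last value yielded by a generator.
def fake_stream():
    text = ""
    for chunk in ["Hello", ", ", "world"]:
        text += chunk
        yield text  # cumulative output, like a streaming completion

output = ""
for output in fake_stream():
    pass

print(output)  # Hello, world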
@@ -280,13 +285,17 @@ class LlamaServer:
             bufsize=1
         )
 
-        def filter_stderr():
-            for line in iter(self.process.stderr.readline, ''):
-                if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
-                    sys.stderr.write(line)
-                    sys.stderr.flush()
+        def filter_stderr(process_stderr):
+            try:
+                for line in iter(process_stderr.readline, ''):
+                    if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
+                        sys.stderr.write(line)
+                        sys.stderr.flush()
+            except (ValueError, IOError):
+                # Handle pipe closed exceptions
+                pass
 
-        threading.Thread(target=filter_stderr, daemon=True).start()
+        threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
 
         # Wait for server to be healthy
         health_url = f"http://localhost:{self.port}/health"
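The rewritten filter_stderr above is more defensive in two ways: the pipe is passed to the thread explicitly instead of being read off self, and a readline on a closed pipe (ValueError/IOError once the server process goes away) ends the thread quietly rather than raising in the background. A standalone sketch of the same forwarding pattern, using a hypothetical child process purely for demonstration:

import subprocess
import sys
import threading

# Illustration only: forward a child's stderr, skipping routine chatter and
# exiting quietly if the pipe closes while the thread is still reading.
def forward_stderr(pipe):
    try:
        for line in iter(pipe.readline, ''):
            if not line.startswith(('srv ', 'slot ')):
                sys.stderr.write(line)
                sys.stderr.flush()
    except (ValueError, IOError):
        pass  # pipe closed underneath us; just let the thread end

# Hypothetical child process, used only to exercise the forwarder.
proc = subprocess.Popen(
    [sys.executable, "-c", "import sys; print('hello from the child', file=sys.stderr)"],
    stderr=subprocess.PIPE,
    text=True,
    bufsize=1,
)
t = threading.Thread(target=forward_stderr, args=(proc.stderr,), daemon=True)
t.start()
proc.wait()
t.join(timeout=2)  # give the forwarder a moment to drain before the script exits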