Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-06-09 07:07:16 -04:00)

Commit 4fa52a1302
1 changed file with 47 additions and 38 deletions
@@ -55,31 +55,6 @@ class LlamaServer:
         return result.get("content", "")
 
     def prepare_payload(self, state):
-        # Prepare DRY
-        dry_sequence_breakers = state['dry_sequence_breakers']
-        if not dry_sequence_breakers.startswith("["):
-            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
-        dry_sequence_breakers = json.loads(dry_sequence_breakers)
-
-        # Prepare the sampler order
-        samplers = state["sampler_priority"]
-        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
-        penalty_found = False
-        filtered_samplers = []
-        for s in samplers:
-            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
-                filtered_samplers.append(s.strip())
-            elif not penalty_found and s.strip() == "repetition_penalty":
-                filtered_samplers.append("penalties")
-                penalty_found = True
-
-        samplers = filtered_samplers
-
-        # Move temperature to the end if temperature_last is true and temperature exists in the list
-        if state["temperature_last"] and "temperature" in samplers:
-            samplers.remove("temperature")
-            samplers.append("temperature")
-
         payload = {
             "temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
             "dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
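This first hunk deletes the up-front DRY and sampler-priority preparation from prepare_payload; the same logic reappears below the payload dict in the +80,44 hunk further down. The DRY normalization itself is unchanged by the move. As a standalone sketch, with a hypothetical helper name and the webui's usual default breaker string as input, the raw UI value is a comma-separated list of quoted strings that only becomes valid JSON once wrapped in brackets:

    import json

    # Hypothetical wrapper around the normalization done in prepare_payload:
    # wrap the raw UI string in brackets so json.loads() can parse it as a list.
    def parse_dry_sequence_breakers(raw):
        if not raw.startswith("["):
            raw = "[" + raw + "]"
        return json.loads(raw)

    print(parse_dry_sequence_breakers('"\\n", ":", "\\"", "*"'))
    # -> ['\n', ':', '"', '*']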
@@ -97,7 +72,6 @@ class LlamaServer:
             "dry_base": state["dry_base"],
             "dry_allowed_length": state["dry_allowed_length"],
             "dry_penalty_last_n": state["repetition_penalty_range"],
-            "dry_sequence_breakers": dry_sequence_breakers,
             "xtc_probability": state["xtc_probability"],
             "xtc_threshold": state["xtc_threshold"],
             "mirostat": state["mirostat_mode"],
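Removing "dry_sequence_breakers" from the dict literal is only half of the move; the next hunk re-adds the key by plain assignment after the dict is built. Both forms produce the same payload, as this toy comparison with dummy values shows:

    breakers = ["\n", ":"]

    # Key set inside the literal vs. assigned afterwards: identical dicts.
    inline = {"dry_base": 1.75, "dry_sequence_breakers": breakers}
    deferred = {"dry_base": 1.75}
    deferred["dry_sequence_breakers"] = breakers

    assert inline == deferred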
@@ -106,20 +80,44 @@ class LlamaServer:
             "grammar": state["grammar_string"],
             "seed": state["seed"],
             "ignore_eos": state["ban_eos_token"],
-            "samplers": samplers,
         }
 
+        # DRY
+        dry_sequence_breakers = state['dry_sequence_breakers']
+        if not dry_sequence_breakers.startswith("["):
+            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+
+        dry_sequence_breakers = json.loads(dry_sequence_breakers)
+        payload["dry_sequence_breakers"] = dry_sequence_breakers
+
+        # Sampler order
+        if state["sampler_priority"]:
+            samplers = state["sampler_priority"]
+            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+            filtered_samplers = []
+
+            penalty_found = False
+            for s in samplers:
+                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                    filtered_samplers.append(s.strip())
+                elif not penalty_found and s.strip() == "repetition_penalty":
+                    filtered_samplers.append("penalties")
+                    penalty_found = True
+
+            # Move temperature to the end if temperature_last is true and temperature exists in the list
+            if state["temperature_last"] and "temperature" in samplers:
+                samplers.remove("temperature")
+                samplers.append("temperature")
+
+            payload["samplers"] = filtered_samplers
+
         if state['custom_token_bans']:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban
 
         return payload
 
-    def generate_with_streaming(
-        self,
-        prompt,
-        state,
-    ):
+    def generate_with_streaming(self, prompt, state):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)
 
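This hunk is the heart of the change: DRY and sampler-priority handling now run after the base payload exists, and the sampler list is only attached when state["sampler_priority"] is non-empty, so an empty priority setting no longer sends a sampler list to the server. The filter keeps only names the llama.cpp server recognizes, collapses repetition_penalty into the single "penalties" stage, and optionally moves temperature to the end. A condensed, self-contained sketch of that filter (the wrapper function is hypothetical, and the move-to-end step here operates on the filtered list):

    KNOWN_SAMPLERS = ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]

    def filter_sampler_priority(priority, temperature_last=False):
        samplers = priority.split("\n") if isinstance(priority, str) else priority
        filtered = []
        penalty_found = False
        for s in samplers:
            s = s.strip()
            if s in KNOWN_SAMPLERS:
                filtered.append(s)
            elif not penalty_found and s == "repetition_penalty":
                # Penalty samplers map onto llama.cpp's single "penalties" stage.
                filtered.append("penalties")
                penalty_found = True
        if temperature_last and "temperature" in filtered:
            filtered.remove("temperature")
            filtered.append("temperature")
        return filtered

    print(filter_sampler_priority(
        "repetition_penalty\ntemperature\ntop_k\nunknown_sampler",
        temperature_last=True,
    ))
    # -> ['penalties', 'top_k', 'temperature']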
@@ -178,6 +176,13 @@ class LlamaServer:
                    print(f"Problematic line: {line}")
                    continue
 
+    def generate(self, prompt, state):
+        output = ""
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
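The new generate() method is a blocking convenience wrapper. generate_with_streaming yields the accumulated text on each chunk, so draining the generator with an empty loop body leaves the final completion in output. A minimal illustration of the pattern with a stand-in generator:

    def stream_completion():
        # Stand-in for generate_with_streaming: yields the cumulative text so far.
        text = ""
        for chunk in ["Hello", ", ", "world", "!"]:
            text += chunk
            yield text

    output = ""
    for output in stream_completion():
        pass  # discard intermediate snapshots; the last yield wins

    print(output)  # -> Hello, world!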
@@ -280,13 +285,17 @@ class LlamaServer:
             bufsize=1
         )
 
-        def filter_stderr():
-            for line in iter(self.process.stderr.readline, ''):
-                if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
-                    sys.stderr.write(line)
-                    sys.stderr.flush()
+        def filter_stderr(process_stderr):
+            try:
+                for line in iter(process_stderr.readline, ''):
+                    if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
+                        sys.stderr.write(line)
+                        sys.stderr.flush()
+            except (ValueError, IOError):
+                # Handle pipe closed exceptions
+                pass
 
-        threading.Thread(target=filter_stderr, daemon=True).start()
+        threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
 
         # Wait for server to be healthy
         health_url = f"http://localhost:{self.port}/health"
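The filter_stderr fix addresses two shutdown problems: the thread used to read self.process.stderr, a reference that can be closed or replaced underneath it when the server restarts, and readline on a closed pipe raises ValueError or IOError, killing the thread with a traceback. Passing the pipe in as an argument pins the thread to the stream it started with, and the try/except lets it exit quietly. A runnable sketch of the same pattern against a throwaway subprocess (the shell command is illustrative only):

    import subprocess
    import sys
    import threading

    def filter_stderr(process_stderr):
        try:
            for line in iter(process_stderr.readline, ''):
                # As in the commit: drop noisy status lines, forward the rest.
                if not line.startswith(('srv ', 'slot ')):
                    sys.stderr.write(line)
                    sys.stderr.flush()
        except (ValueError, IOError):
            pass  # pipe closed while the process was shutting down

    proc = subprocess.Popen(
        ["sh", "-c", "echo 'srv  noisy status' >&2; echo 'real error' >&2"],
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1,
    )
    t = threading.Thread(target=filter_stderr, args=(proc.stderr,), daemon=True)
    t.start()
    proc.wait()
    t.join(timeout=2)  # only 'real error' reaches our stderr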