Merge pull request #6852 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2025-04-18 22:15:40 -03:00 committed by GitHub
commit 4fa52a1302
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -55,31 +55,6 @@ class LlamaServer:
return result.get("content", "")
def prepare_payload(self, state):
# Prepare DRY
dry_sequence_breakers = state['dry_sequence_breakers']
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
dry_sequence_breakers = json.loads(dry_sequence_breakers)
# Prepare the sampler order
samplers = state["sampler_priority"]
samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
penalty_found = False
filtered_samplers = []
for s in samplers:
if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
filtered_samplers.append(s.strip())
elif not penalty_found and s.strip() == "repetition_penalty":
filtered_samplers.append("penalties")
penalty_found = True
samplers = filtered_samplers
# Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in samplers:
samplers.remove("temperature")
samplers.append("temperature")
payload = {
"temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
"dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
@ -97,7 +72,6 @@ class LlamaServer:
"dry_base": state["dry_base"],
"dry_allowed_length": state["dry_allowed_length"],
"dry_penalty_last_n": state["repetition_penalty_range"],
"dry_sequence_breakers": dry_sequence_breakers,
"xtc_probability": state["xtc_probability"],
"xtc_threshold": state["xtc_threshold"],
"mirostat": state["mirostat_mode"],
@ -106,20 +80,44 @@ class LlamaServer:
"grammar": state["grammar_string"],
"seed": state["seed"],
"ignore_eos": state["ban_eos_token"],
"samplers": samplers,
}
# DRY
dry_sequence_breakers = state['dry_sequence_breakers']
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
dry_sequence_breakers = json.loads(dry_sequence_breakers)
payload["dry_sequence_breakers"] = dry_sequence_breakers
# Sampler order
if state["sampler_priority"]:
samplers = state["sampler_priority"]
samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
filtered_samplers = []
penalty_found = False
for s in samplers:
if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
filtered_samplers.append(s.strip())
elif not penalty_found and s.strip() == "repetition_penalty":
filtered_samplers.append("penalties")
penalty_found = True
# Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in samplers:
samplers.remove("temperature")
samplers.append("temperature")
payload["samplers"] = filtered_samplers
if state['custom_token_bans']:
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
return payload
def generate_with_streaming(
self,
prompt,
state,
):
def generate_with_streaming(self, prompt, state):
url = f"http://localhost:{self.port}/completion"
payload = self.prepare_payload(state)
@ -178,6 +176,13 @@ class LlamaServer:
print(f"Problematic line: {line}")
continue
def generate(self, prompt, state):
output = ""
for output in self.generate_with_streaming(prompt, state):
pass
return output
def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
"""Get the logits/probabilities for the next token after a prompt"""
url = f"http://localhost:{self.port}/completion"
@ -280,13 +285,17 @@ class LlamaServer:
bufsize=1
)
def filter_stderr():
for line in iter(self.process.stderr.readline, ''):
if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
sys.stderr.write(line)
sys.stderr.flush()
def filter_stderr(process_stderr):
try:
for line in iter(process_stderr.readline, ''):
if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
sys.stderr.write(line)
sys.stderr.flush()
except (ValueError, IOError):
# Handle pipe closed exceptions
pass
threading.Thread(target=filter_stderr, daemon=True).start()
threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
# Wait for server to be healthy
health_url = f"http://localhost:{self.port}/health"