Merge pull request #6857 from oobabooga/dev

Merge dev branch
commit c19b995b8e
oobabooga, 2025-04-19 21:45:55 -03:00, committed by GitHub

@@ -35,7 +35,7 @@ class LlamaServer:
         if self.bos_token and text.startswith(self.bos_token):
             add_bos_token = False
 
-        url = f"http://localhost:{self.port}/tokenize"
+        url = f"http://127.0.0.1:{self.port}/tokenize"
         payload = {
             "content": text,
             "add_special": add_bos_token,
@@ -46,7 +46,7 @@ class LlamaServer:
         return result.get("tokens", [])
 
     def decode(self, token_ids, **kwargs):
-        url = f"http://localhost:{self.port}/detokenize"
+        url = f"http://127.0.0.1:{self.port}/detokenize"
         payload = {
             "tokens": token_ids,
         }
@@ -119,7 +119,7 @@ class LlamaServer:
         return payload
 
     def generate_with_streaming(self, prompt, state):
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
         token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
@@ -147,34 +147,37 @@ class LlamaServer:
 
         full_text = ""
 
         # Process the streaming response
-        for line in response.iter_lines(decode_unicode=True):
+        for line in response.iter_lines():
             if shared.stop_everything:
                 break
 
-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    if line.startswith('data: '):
-                        line = line[6:]  # Remove the "data: " prefix
-
-                    # Parse the JSON data
-                    data = json.loads(line)
-
-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
-
-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+            if not line:
+                continue
+
+            try:
+                line = line.decode('utf-8')
+
+                # Check if the line starts with "data: " and remove it
+                if line.startswith('data: '):
+                    line = line[6:]  # Remove the "data: " prefix
+
+                # Parse the JSON data
+                data = json.loads(line)
+
+                # Extract the token content
+                if data.get('content', ''):
+                    full_text += data['content']
+                    yield full_text
+
+                # Check if generation is complete
+                if data.get('stop', False):
+                    break
+            except json.JSONDecodeError as e:
+                # Log the error and the problematic line
+                print(f"JSON decode error: {e}")
+                print(f"Problematic line: {line}")
+                continue
 
     def generate(self, prompt, state):
         output = ""
@@ -185,7 +188,7 @@ class LlamaServer:
 
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
 
         payload = self.prepare_payload(state)
         payload.update({
@@ -216,7 +219,7 @@ class LlamaServer:
 
     def _get_vocabulary_size(self):
         """Get and store the model's maximum context length."""
-        url = f"http://localhost:{self.port}/v1/models"
+        url = f"http://127.0.0.1:{self.port}/v1/models"
         response = self.session.get(url).json()
 
         if "data" in response and len(response["data"]) > 0:
@@ -226,7 +229,7 @@ class LlamaServer:
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
-        url = f"http://localhost:{self.port}/props"
+        url = f"http://127.0.0.1:{self.port}/props"
         response = self.session.get(url).json()
 
         if "bos_token" in response:
             self.bos_token = response["bos_token"]
@@ -299,7 +302,7 @@ class LlamaServer:
         threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
 
         # Wait for server to be healthy
-        health_url = f"http://localhost:{self.port}/health"
+        health_url = f"http://127.0.0.1:{self.port}/health"
         start_time = time.time()
         timeout = 3600 * 8  # 8 hours
         while time.time() - start_time < timeout:
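
The same loopback change applies to the startup health poll. A minimal sketch of the wait pattern under stated assumptions: llama-server's /health endpoint is assumed to answer 200 once the model is loaded (and to be unreachable, or to answer with an error status, while it is still loading), the port and 1-second poll interval are guesses; only the URL shape and the 8-hour timeout come from the diff:

import time

import requests

def wait_for_server(port, timeout=3600 * 8):  # 8 hours, as in the diff
    health_url = f"http://127.0.0.1:{port}/health"
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # Assumed behavior: /health returns 200 once the model is loaded
            if requests.get(health_url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.ConnectionError:
            pass  # server process hasn't started listening yet
        time.sleep(1)  # assumed poll interval
    return False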