Merge pull request #6854 from oobabooga/dev

Merge dev branch
oobabooga 2025-04-18 23:41:56 -03:00, committed by GitHub
commit 44a6d8a761

@@ -24,6 +24,7 @@ class LlamaServer:
         self.server_path = server_path
         self.port = self._find_available_port()
         self.process = None
+        self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = "<s>"
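
The new `requests.Session()` means every HTTP call to the local llama.cpp server reuses one pooled, keep-alive TCP connection instead of opening and tearing down a fresh connection per request. A minimal sketch of the effect (the port and payload here are illustrative, not taken from this diff):

import requests

# One-off style: each requests.post() negotiates its own connection.
# requests.post("http://localhost:8080/tokenize", json={"content": "hi"})

# Pooled style: the Session keeps the connection alive between calls,
# which matters when tokenize/detokenize run on every generation.
session = requests.Session()
for text in ["first chunk", "second chunk"]:
    r = session.post("http://localhost:8080/tokenize", json={"content": text})
    r.raise_for_status()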
@@ -40,7 +41,7 @@ class LlamaServer:
             "add_special": add_bos_token,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("tokens", [])
@@ -50,7 +51,7 @@ class LlamaServer:
             "tokens": token_ids,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("content", "")
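
These two hunks are the tokenize/detokenize round trip. A hedged sketch of how the pair composes, assuming the llama.cpp server's /tokenize and /detokenize endpoints and an already-running server on `port` (the helper itself is hypothetical):

def roundtrip(session, port, text):
    # Text -> token ids ("add_special" mirrors the add_bos_token flag above)
    tokens = session.post(
        f"http://localhost:{port}/tokenize",
        json={"content": text, "add_special": False},
    ).json().get("tokens", [])

    # Token ids -> text; should reproduce the input
    return session.post(
        f"http://localhost:{port}/detokenize",
        json={"tokens": tokens},
    ).json().get("content", "")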
@@ -139,42 +140,41 @@ class LlamaServer:
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
         print()

-        # Make a direct request with streaming enabled
-        response = requests.post(url, json=payload, stream=True)
-        response.raise_for_status()  # Raise an exception for HTTP errors
+        # Make a direct request with streaming enabled using a context manager
+        with self.session.post(url, json=payload, stream=True) as response:
+            response.raise_for_status()  # Raise an exception for HTTP errors

-        full_text = ""
+            full_text = ""

-        # Process the streaming response
-        for line in response.iter_lines():
-            if shared.stop_everything:
-                break
+            # Process the streaming response
+            for line in response.iter_lines(decode_unicode=True):
+                if shared.stop_everything:
+                    break

-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    line_str = line.decode('utf-8')
-                    if line_str.startswith('data: '):
-                        line_str = line_str[6:]  # Remove the "data: " prefix
+                if line:
+                    try:
+                        # Check if the line starts with "data: " and remove it
+                        if line.startswith('data: '):
+                            line = line[6:]  # Remove the "data: " prefix

-                    # Parse the JSON data
-                    data = json.loads(line_str)
+                        # Parse the JSON data
+                        data = json.loads(line)

-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                        # Extract the token content
+                        if 'content' in data:
+                            token_text = data['content']
+                            full_text += token_text
+                            yield full_text

-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
+                        # Check if generation is complete
+                        if data.get('stop', False):
+                            break

-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+                    except json.JSONDecodeError as e:
+                        # Log the error and the problematic line
+                        print(f"JSON decode error: {e}")
+                        print(f"Problematic line: {line}")
+                        continue

     def generate(self, prompt, state):
         output = ""
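
Two things change in this hunk. `decode_unicode=True` makes `iter_lines()` yield `str` rather than `bytes`, so the manual `.decode('utf-8')` step disappears. And wrapping the response in a `with` block guarantees the connection is released back to the Session's pool even when the generator exits early (for example on `shared.stop_everything`). A standalone sketch of the same SSE parsing pattern, with illustrative names:

import json
import requests

def stream_sse(session: requests.Session, url: str, payload: dict):
    """Minimal sketch of the streaming pattern used above (illustrative)."""
    with session.post(url, json=payload, stream=True) as response:
        response.raise_for_status()
        # decode_unicode=True yields str lines, so no manual .decode('utf-8')
        for line in response.iter_lines(decode_unicode=True):
            if not line:
                continue  # skip SSE keep-alive blank lines
            if line.startswith("data: "):
                line = line[6:]  # strip the SSE event prefix
            data = json.loads(line)
            yield data
            if data.get("stop", False):
                break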
@@ -203,7 +203,7 @@ class LlamaServer:
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
         print()

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()

         if "completion_probabilities" in result:
@@ -217,7 +217,7 @@ class LlamaServer:
     def _get_vocabulary_size(self):
         """Get and store the model's vocabulary size."""
         url = f"http://localhost:{self.port}/v1/models"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()

         if "data" in response and len(response["data"]) > 0:
             model_info = response["data"][0]
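
For context, this lookup reads the first entry of the OpenAI-style /v1/models response. A hedged sketch; field names like "meta" and "n_vocab" match what current llama.cpp server builds report, but treat them as assumptions:

def get_vocab_size(session, port):
    response = session.get(f"http://localhost:{port}/v1/models").json()
    if response.get("data"):
        model_info = response["data"][0]
        # llama.cpp exposes model metadata under "meta" (assumed key)
        return model_info.get("meta", {}).get("n_vocab")
    return None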
@@ -227,7 +227,7 @@ class LlamaServer:
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
         url = f"http://localhost:{self.port}/props"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()

         if "bos_token" in response:
             self.bos_token = response["bos_token"]
@@ -309,7 +309,7 @@ class LlamaServer:
             raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")

         try:
-            response = requests.get(health_url)
+            response = self.session.get(health_url)
             if response.status_code == 200:
                 break
         except:
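
The startup health check also moves to the shared session. A sketch of the surrounding polling pattern, assuming the server's /health endpoint and a hypothetical `timeout` budget (the real loop above also watches the child process's exit code):

import time
import requests

def wait_for_server(session: requests.Session, health_url: str, timeout: float = 60.0):
    """Poll /health until the server answers 200 or the budget runs out (sketch)."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if session.get(health_url).status_code == 200:
                return True
        except requests.exceptions.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(0.5)
    return False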