mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-07 06:06:20 -04:00
commit 44a6d8a761
1 changed file with 35 additions and 35 deletions
@@ -24,6 +24,7 @@ class LlamaServer:
         self.server_path = server_path
         self.port = self._find_available_port()
         self.process = None
+        self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = "<s>"

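The one added line above is what the rest of the commit leans on: a requests.Session keeps a pooled TCP connection to the local server instead of opening a fresh one for every call. A minimal sketch of the difference (port and endpoint are placeholders, not values from the module):

import requests

# Plain module-level calls open (and tear down) a connection per request
requests.get("http://localhost:8080/health")
requests.get("http://localhost:8080/health")

# A Session pools connections per host, so repeated requests to the
# same local server reuse the open socket
session = requests.Session()
session.get("http://localhost:8080/health")
session.get("http://localhost:8080/health")  # reuses the pooled connection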
@@ -40,7 +41,7 @@ class LlamaServer:
             "add_special": add_bos_token,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("tokens", [])

@@ -50,7 +51,7 @@ class LlamaServer:
             "tokens": token_ids,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("content", "")

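The two hunks above are the tokenize and detokenize helpers. A self-contained sketch of the same request/response shapes against a llama.cpp server (the /tokenize and /detokenize paths and the port are assumptions inferred from the payload fields visible in the hunks):

import requests

session = requests.Session()
port = 8080  # stand-in for self.port

# Tokenize: same payload fields as the first hunk above
resp = session.post(f"http://localhost:{port}/tokenize",
                    json={"content": "Hello there", "add_special": True})
token_ids = resp.json().get("tokens", [])

# Detokenize: round-trip the ids back to text, as in the second hunk
resp = session.post(f"http://localhost:{port}/detokenize",
                    json={"tokens": token_ids})
text = resp.json().get("content", "")
print(token_ids, "->", repr(text))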
@@ -139,42 +140,41 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()

-        # Make a direct request with streaming enabled
-        response = requests.post(url, json=payload, stream=True)
-        response.raise_for_status()  # Raise an exception for HTTP errors
+        # Make a direct request with streaming enabled using a context manager
+        with self.session.post(url, json=payload, stream=True) as response:
+            response.raise_for_status()  # Raise an exception for HTTP errors

-        full_text = ""
+            full_text = ""

-        # Process the streaming response
-        for line in response.iter_lines():
-            if shared.stop_everything:
-                break
+            # Process the streaming response
+            for line in response.iter_lines(decode_unicode=True):
+                if shared.stop_everything:
+                    break

-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    line_str = line.decode('utf-8')
-                    if line_str.startswith('data: '):
-                        line_str = line_str[6:]  # Remove the "data: " prefix
+                if line:
+                    try:
+                        # Check if the line starts with "data: " and remove it
+                        if line.startswith('data: '):
+                            line = line[6:]  # Remove the "data: " prefix

-                    # Parse the JSON data
-                    data = json.loads(line_str)
+                        # Parse the JSON data
+                        data = json.loads(line)

-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                        # Extract the token content
+                        if 'content' in data:
+                            token_text = data['content']
+                            full_text += token_text
+                            yield full_text

-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
+                        # Check if generation is complete
+                        if data.get('stop', False):
+                            break

-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+                    except json.JSONDecodeError as e:
+                        # Log the error and the problematic line
+                        print(f"JSON decode error: {e}")
+                        print(f"Problematic line: {line}")
+                        continue

     def generate(self, prompt, state):
         output = ""
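The streaming hunk changes three things at once: the response now lives in a with block, so the connection is released back to the session's pool even if the generator is abandoned mid-stream; iter_lines(decode_unicode=True) yields str instead of bytes, which is why the manual line.decode('utf-8') step disappears; and the request goes through the shared session. A standalone sketch of the resulting pattern (URL and payload are placeholders; the "data: " handling matches the SSE framing the code strips):

import json
import requests

session = requests.Session()
url = "http://localhost:8080/completion"  # placeholder endpoint
payload = {"prompt": "Hello", "stream": True}

with session.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    full_text = ""
    for line in response.iter_lines(decode_unicode=True):  # str, not bytes
        if not line:
            continue
        if line.startswith("data: "):
            line = line[6:]  # strip the SSE "data: " prefix
        data = json.loads(line)
        full_text += data.get("content", "")
        if data.get("stop", False):
            break
print(full_text)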
@@ -203,7 +203,7 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()

         if "completion_probabilities" in result:
@@ -217,7 +217,7 @@ class LlamaServer:
     def _get_vocabulary_size(self):
         """Get and store the model's maximum context length."""
         url = f"http://localhost:{self.port}/v1/models"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()

         if "data" in response and len(response["data"]) > 0:
             model_info = response["data"][0]
@@ -227,7 +227,7 @@ class LlamaServer:
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
         url = f"http://localhost:{self.port}/props"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()
         if "bos_token" in response:
             self.bos_token = response["bos_token"]

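The two hunks above apply the same one-line substitution to the startup metadata helpers. For reference, a combined sketch of both queries (the port is a stand-in for self.port; the field names are the ones the hunks read):

import requests

session = requests.Session()
port = 8080  # stand-in for self.port

# /v1/models: OpenAI-compatible listing; the first entry describes the loaded model
models = session.get(f"http://localhost:{port}/v1/models").json()
if "data" in models and len(models["data"]) > 0:
    model_info = models["data"][0]

# /props: server properties; fall back to the class default when absent
props = session.get(f"http://localhost:{port}/props").json()
bos_token = props.get("bos_token", "<s>")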
@@ -309,7 +309,7 @@ class LlamaServer:
                 raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")

             try:
-                response = requests.get(health_url)
+                response = self.session.get(health_url)
                 if response.status_code == 200:
                     break
             except:
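The final hunk sits inside the startup wait loop: poll the health endpoint until the server answers 200, and fail fast if the child process has already died. A minimal sketch of that loop (the command, port, and endpoint are placeholders; catching requests.ConnectionError here is a tidier variant of the bare except: in the original, not what the code does):

import time
import subprocess
import requests

session = requests.Session()
process = subprocess.Popen(["./llama-server", "--port", "8080"])  # placeholder command
health_url = "http://localhost:8080/health"

while True:
    # If the child exited, surface its exit code instead of polling forever
    exit_code = process.poll()
    if exit_code is not None:
        raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")

    try:
        if session.get(health_url).status_code == 200:
            break  # server is ready
    except requests.ConnectionError:
        pass  # server not listening yet

    time.sleep(0.1)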