mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-07 14:17:09 -04:00

commit c19b995b8e

1 changed file with 31 additions and 28 deletions
@@ -35,7 +35,7 @@ class LlamaServer:
         if self.bos_token and text.startswith(self.bos_token):
             add_bos_token = False

-        url = f"http://localhost:{self.port}/tokenize"
+        url = f"http://127.0.0.1:{self.port}/tokenize"
         payload = {
             "content": text,
             "add_special": add_bos_token,
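Every URL in this commit moves from localhost to the literal loopback address 127.0.0.1. On systems where localhost resolves to the IPv6 address ::1 first, each request to an IPv4-only server can pay a failed connection attempt or a name-resolution round trip before falling back, so the literal address is both faster and less ambiguous. Below is a minimal sketch of the resulting /tokenize call, assuming a llama.cpp-style server on the given port; tokenize is an illustrative helper, not a name from this codebase.

    import requests

    def tokenize(session: requests.Session, port: int, text: str, add_special: bool = True):
        """Hypothetical helper mirroring the /tokenize request shape in this diff."""
        # Literal IPv4 loopback: no localhost -> ::1 resolution surprises.
        url = f"http://127.0.0.1:{port}/tokenize"
        payload = {"content": text, "add_special": add_special}
        response = session.post(url, json=payload)
        response.raise_for_status()
        # The server answers with {"tokens": [...]} on success.
        return response.json().get("tokens", [])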
@@ -46,7 +46,7 @@ class LlamaServer:
         return result.get("tokens", [])

     def decode(self, token_ids, **kwargs):
-        url = f"http://localhost:{self.port}/detokenize"
+        url = f"http://127.0.0.1:{self.port}/detokenize"
         payload = {
             "tokens": token_ids,
         }
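/detokenize is the inverse endpoint, accepting {"tokens": [...]} and returning text. A round-trip sketch under the same assumptions as above; the "content" field in the response is a llama.cpp server convention, so treat it as an assumption for other backends.

    import requests

    def detokenize(session: requests.Session, port: int, token_ids: list) -> str:
        """Hypothetical counterpart to the /tokenize helper above."""
        url = f"http://127.0.0.1:{port}/detokenize"
        response = session.post(url, json={"tokens": token_ids})
        response.raise_for_status()
        # Assumed response shape: {"content": "<decoded text>"}
        return response.json().get("content", "")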
@@ -119,7 +119,7 @@ class LlamaServer:
         return payload

     def generate_with_streaming(self, prompt, state):
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)

         token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
@@ -147,34 +147,37 @@ class LlamaServer:
         full_text = ""

         # Process the streaming response
-        for line in response.iter_lines(decode_unicode=True):
+        for line in response.iter_lines():
             if shared.stop_everything:
                 break

-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    if line.startswith('data: '):
-                        line = line[6:]  # Remove the "data: " prefix
+            if not line:
+                continue

-                    # Parse the JSON data
-                    data = json.loads(line)
+            try:
+                line = line.decode('utf-8')

-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                # Check if the line starts with "data: " and remove it
+                if line.startswith('data: '):
+                    line = line[6:]  # Remove the "data: " prefix

-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
+                # Parse the JSON data
+                data = json.loads(line)

-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+                # Extract the token content
+                if data.get('content', ''):
+                    full_text += data['content']
+                    yield full_text
+
+                # Check if generation is complete
+                if data.get('stop', False):
+                    break
+
+            except json.JSONDecodeError as e:
+                # Log the error and the problematic line
+                print(f"JSON decode error: {e}")
+                print(f"Problematic line: {line}")
+                continue

     def generate(self, prompt, state):
         output = ""
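The streaming hunk makes two behavioral changes. First, iter_lines(decode_unicode=True) let requests decode each line with whatever encoding it guessed for the response; the new code takes raw bytes from iter_lines() and decodes them explicitly as UTF-8, so a wrong charset guess cannot corrupt multi-byte characters mid-stream. Second, the nested "if line: try:" structure is flattened with an early continue on blank keep-alive lines. A self-contained sketch of the same consumption pattern, assuming the server emits server-sent-event lines of the form data: {"content": "...", "stop": false}; the JSON error handling from the diff is omitted for brevity.

    import json
    import requests

    def stream_completion(url: str, payload: dict):
        """Sketch of consuming a llama.cpp-style streaming /completion response."""
        full_text = ""
        with requests.post(url, json=payload, stream=True) as response:
            for line in response.iter_lines():   # yields raw bytes per event line
                if not line:                     # skip SSE keep-alive blank lines
                    continue
                line = line.decode('utf-8')      # explicit decode, no charset guessing
                if line.startswith('data: '):
                    line = line[6:]              # strip the SSE "data: " prefix
                data = json.loads(line)
                if data.get('content', ''):
                    full_text += data['content']
                    yield full_text              # yield the accumulated text so far
                if data.get('stop', False):      # server signals end of generation
                    break

Usage would be along the lines of iterating over stream_completion(f"http://127.0.0.1:{port}/completion", {"prompt": "Hello", "stream": True}), where the payload keys are assumptions about the server's completion API rather than names taken from this codebase.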
@@ -185,7 +188,7 @@ class LlamaServer:

     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"

         payload = self.prepare_payload(state)
         payload.update({
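get_logits reuses /completion rather than a dedicated endpoint: llama.cpp's server can attach per-token probability data when the request sets n_probs. A sketch of such a request follows; the n_probs and completion_probabilities field names are llama.cpp server conventions and should be verified against the server version in use.

    import requests

    def get_top_probs(session: requests.Session, port: int, prompt: str, n_probs: int = 128):
        """Hypothetical sketch: generate one token and return its top candidates."""
        url = f"http://127.0.0.1:{port}/completion"
        payload = {
            "prompt": prompt,
            "n_predict": 1,      # one step is enough for next-token probabilities
            "n_probs": n_probs,  # ask for the top-N candidate tokens per step
            "stream": False,
        }
        response = session.post(url, json=payload)
        response.raise_for_status()
        # Assumed shape: [{"probs": [{"tok_str": ..., "prob": ...}, ...]}, ...]
        return response.json().get("completion_probabilities", [])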
@@ -216,7 +219,7 @@ class LlamaServer:

     def _get_vocabulary_size(self):
         """Get and store the model's vocabulary size."""
-        url = f"http://localhost:{self.port}/v1/models"
+        url = f"http://127.0.0.1:{self.port}/v1/models"
         response = self.session.get(url).json()

         if "data" in response and len(response["data"]) > 0:
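The OpenAI-compatible /v1/models listing doubles as a metadata source here. llama.cpp's server adds a meta object to each model entry (with keys such as n_vocab); the exact keys are version-dependent assumptions. A sketch matching the access pattern in the context lines above:

    import requests

    def get_vocab_size(session: requests.Session, port: int) -> int:
        """Hypothetical sketch reading the vocabulary size from /v1/models."""
        url = f"http://127.0.0.1:{port}/v1/models"
        response = session.get(url).json()
        if "data" in response and len(response["data"]) > 0:
            # "meta" and "n_vocab" are llama.cpp conventions; verify per version.
            return response["data"][0].get("meta", {}).get("n_vocab", -1)
        return -1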
@@ -226,7 +229,7 @@ class LlamaServer:

     def _get_bos_token(self):
         """Get and store the model's BOS token."""
-        url = f"http://localhost:{self.port}/props"
+        url = f"http://127.0.0.1:{self.port}/props"
         response = self.session.get(url).json()
         if "bos_token" in response:
             self.bos_token = response["bos_token"]
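The BOS token cached here is what the first hunk's context lines consume: if the prompt text already begins with the BOS string, encode passes add_special=False so the server does not prepend a second BOS token. A condensed sketch of that guard, using the same hypothetical helper style as above:

    import requests

    def encode_without_double_bos(session: requests.Session, port: int, text: str, bos_token):
        """Sketch of the double-BOS guard visible in this diff's first hunk."""
        add_special = True
        if bos_token and text.startswith(bos_token):
            # The prompt already carries BOS; don't let the server add another.
            add_special = False
        payload = {"content": text, "add_special": add_special}
        response = session.post(f"http://127.0.0.1:{port}/tokenize", json=payload)
        return response.json().get("tokens", [])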
@@ -299,7 +302,7 @@ class LlamaServer:
         threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()

         # Wait for server to be healthy
-        health_url = f"http://localhost:{self.port}/health"
+        health_url = f"http://127.0.0.1:{self.port}/health"
         start_time = time.time()
         timeout = 3600 * 8  # 8 hours
         while time.time() - start_time < timeout:
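Startup polls /health until the server reports ready. llama.cpp's server returns HTTP 200 once the model is loaded and an error status (typically 503) while it is still loading, though exact codes vary by version. A polling sketch with the same generous timeout:

    import time
    import requests

    def wait_for_health(port: int, timeout: float = 3600 * 8) -> bool:
        """Sketch of the readiness loop: poll /health until 200 OK or timeout."""
        health_url = f"http://127.0.0.1:{port}/health"
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                if requests.get(health_url, timeout=5).status_code == 200:
                    return True       # model loaded, server is ready
            except requests.exceptions.ConnectionError:
                pass                  # socket not accepting connections yet
            time.sleep(1)
        return False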