Rename get_max_context_length to get_vocabulary_size in the new llama.cpp loader

This commit is contained in:
oobabooga 2025-04-18 08:14:15 -07:00
parent c1cc65e82e
commit d00d713ace

View file

@ -24,7 +24,7 @@ class LlamaServer:
self.server_path = server_path self.server_path = server_path
self.port = self._find_available_port() self.port = self._find_available_port()
self.process = None self.process = None
self.max_context_length = None self.vocabulary_size = None
self.bos_token = "<s>" self.bos_token = "<s>"
# Start the server # Start the server
@ -209,7 +209,7 @@ class LlamaServer:
else: else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
def _get_max_context_length(self): def _get_vocabulary_size(self):
"""Get and store the model's maximum context length.""" """Get and store the model's maximum context length."""
url = f"http://localhost:{self.port}/v1/models" url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url).json() response = requests.get(url).json()
@ -217,7 +217,7 @@ class LlamaServer:
if "data" in response and len(response["data"]) > 0: if "data" in response and len(response["data"]) > 0:
model_info = response["data"][0] model_info = response["data"][0]
if "meta" in model_info and "n_vocab" in model_info["meta"]: if "meta" in model_info and "n_vocab" in model_info["meta"]:
self.max_context_length = model_info["meta"]["n_vocab"] self.vocabulary_size = model_info["meta"]["n_vocab"]
def _get_bos_token(self): def _get_bos_token(self):
"""Get and store the model's BOS token.""" """Get and store the model's BOS token."""
@ -311,7 +311,7 @@ class LlamaServer:
raise TimeoutError(f"Server health check timed out after {timeout} seconds") raise TimeoutError(f"Server health check timed out after {timeout} seconds")
# Server is now healthy, get model info # Server is now healthy, get model info
self._get_max_context_length() self._get_vocabulary_size()
self._get_bos_token() self._get_bos_token()
return self.port return self.port