fix exception when not valid model

This commit is contained in:
Cyberes 2023-10-05 12:28:00 -06:00
parent acf409abfc
commit 08df52a4fd
4 changed files with 36 additions and 27 deletions

View File

@ -24,7 +24,6 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
pass
prompt_tokens = get_token_count(prompt, backend_url)
print('starting')
if not is_error:
if not response_tokens:

View File

@ -34,11 +34,12 @@ class RequestHandler:
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
self.parameters = None
self.used = False
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
if not self.cluster_backend_info.get('mode'):
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model'):

View File

@ -9,7 +9,6 @@ from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock
@ -79,6 +78,17 @@ def do_stream(ws, model_name):
return auth_failure
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.offline:
msg = f'{handler.selected_model} is not a valid model choice.'
print(msg)
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': msg
}))
return
assert not handler.offline
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends

View File

@ -25,4 +25,3 @@ def db_logger():
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
print('finished log')