fix exception when model is not valid

This commit is contained in:
Cyberes 2023-10-05 12:28:00 -06:00
parent acf409abfc
commit 08df52a4fd
4 changed files with 36 additions and 27 deletions

View File

@ -24,7 +24,6 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
pass
prompt_tokens = get_token_count(prompt, backend_url)
print('starting')
if not is_error:
if not response_tokens:

View File

@ -34,11 +34,12 @@ class RequestHandler:
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
self.parameters = None
self.used = False
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
if not self.cluster_backend_info.get('mode'):
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model'):

View File

@ -9,7 +9,6 @@ from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock
@ -47,18 +46,18 @@ def do_stream(ws, model_name):
'message_num': 1
}))
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.cluster_backend_info,
response_tokens=None,
is_error=True
)
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.cluster_backend_info,
response_tokens=None,
is_error=True
)
if not opts.enable_streaming:
return 'Streaming is disabled', 500
@ -79,6 +78,17 @@ def do_stream(ws, model_name):
return auth_failure
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.offline:
msg = f'{handler.selected_model} is not a valid model choice.'
print(msg)
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': msg
}))
return
assert not handler.offline
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
@ -199,16 +209,16 @@ def do_stream(ws, model_name):
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url
)
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url
)
finally:
try:
# Must close the connection or greenlets will complain.

View File

@ -25,4 +25,3 @@ def db_logger():
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
print('finished log')