From 08df52a4fd16f3cff19c8c63cc83b0feed3b5931 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Thu, 5 Oct 2023 12:28:00 -0600
Subject: [PATCH] fix exception when not valid model

---
 llm_server/database/database.py         |  1 -
 llm_server/routes/request_handler.py    |  5 ++-
 llm_server/routes/v1/generate_stream.py | 56 +++++++++++++++----------
 llm_server/workers/logger.py            |  1 -
 4 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index fc800a2..d6bd6b2 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -24,7 +24,6 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
         pass
 
     prompt_tokens = get_token_count(prompt, backend_url)
-    print('starting')
 
     if not is_error:
         if not response_tokens:
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 90da0b1..ef5aa34 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -34,11 +34,12 @@ class RequestHandler:
         self.client_ip = self.get_client_ip()
         self.token = self.get_auth_token()
         self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
-        self.backend_url = get_a_cluster_backend(selected_model)
-        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
         self.parameters = None
         self.used = False
 
+        self.backend_url = get_a_cluster_backend(selected_model)
+        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
+
         if not self.cluster_backend_info.get('mode'):
             print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
         if not self.cluster_backend_info.get('model'):
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 9962ff8..55fceb9 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -9,7 +9,6 @@ from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import do_db_log
 from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
 from ...sock import sock
@@ -47,18 +46,18 @@ def do_stream(ws, model_name):
             'message_num': 1
         }))
         log_to_db(ip=handler.client_ip,
-                      token=handler.token,
-                      prompt=input_prompt,
-                      response=quitting_err_msg,
-                      gen_time=None,
-                      parameters=handler.parameters,
-                      headers=r_headers,
-                      backend_response_code=response_status_code,
-                      request_url=r_url,
-                      backend_url=handler.cluster_backend_info,
-                      response_tokens=None,
-                      is_error=True
-                      )
+                  token=handler.token,
+                  prompt=input_prompt,
+                  response=quitting_err_msg,
+                  gen_time=None,
+                  parameters=handler.parameters,
+                  headers=r_headers,
+                  backend_response_code=response_status_code,
+                  request_url=r_url,
+                  backend_url=handler.cluster_backend_info,
+                  response_tokens=None,
+                  is_error=True
+                  )
 
     if not opts.enable_streaming:
         return 'Streaming is disabled', 500
@@ -79,6 +78,17 @@ def do_stream(ws, model_name):
             return auth_failure
 
     handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
+    if handler.offline:
+        msg = f'{handler.selected_model} is not a valid model choice.'
+        print(msg)
+        ws.send(json.dumps({
+            'event': 'text_stream',
+            'message_num': 0,
+            'text': msg
+        }))
+        return
+
+    assert not handler.offline
 
     if handler.cluster_backend_info['mode'] != 'vllm':
         # TODO: implement other backends
@@ -199,16 +209,16 @@ def do_stream(ws, model_name):
             end_time = time.time()
             elapsed_time = end_time - start_time
             log_to_db(ip=handler.client_ip,
-                          token=handler.token,
-                          prompt=input_prompt,
-                          response=generated_text,
-                          gen_time=elapsed_time,
-                          parameters=handler.parameters,
-                          headers=r_headers,
-                          backend_response_code=response_status_code,
-                          request_url=r_url,
-                          backend_url=handler.backend_url
-                          )
+                      token=handler.token,
+                      prompt=input_prompt,
+                      response=generated_text,
+                      gen_time=elapsed_time,
+                      parameters=handler.parameters,
+                      headers=r_headers,
+                      backend_response_code=response_status_code,
+                      request_url=r_url,
+                      backend_url=handler.backend_url
+                      )
     finally:
         try:
             # Must close the connection or greenlets will complain.
diff --git a/llm_server/workers/logger.py b/llm_server/workers/logger.py
index 2707615..cb8dc01 100644
--- a/llm_server/workers/logger.py
+++ b/llm_server/workers/logger.py
@@ -25,4 +25,3 @@ def db_logger():
 
         if function_name == 'log_prompt':
             do_db_log(*args, **kwargs)
-            print('finished log')
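
Note on the new failure path in generate_stream.py: a client that opens the stream socket for a model the cluster does not serve now receives a single text_stream event carrying the error message, and the server returns instead of raising. Below is a minimal client-side sketch of that behavior; the ws://localhost:5000/api/v1/stream/<model> URL and the one-field payload are assumptions about the deployment (only the reply shape comes from the hunk above), and it relies on the third-party websocket-client package.

import json

from websocket import create_connection  # pip install websocket-client

# Hypothetical endpoint; adjust host, port, and route to match the deployment.
ws = create_connection('ws://localhost:5000/api/v1/stream/not-a-real-model')
try:
    # Oobabooga-style first message; the prompt never reaches a backend because
    # the handler is offline for an unknown model and the new guard replies first.
    ws.send(json.dumps({'prompt': 'Hello'}))
    reply = json.loads(ws.recv())
    # Expected shape per the patch:
    # {'event': 'text_stream', 'message_num': 0, 'text': 'not-a-real-model is not a valid model choice.'}
    print(reply.get('text'))
finally:
    ws.close()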