diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 909848e..6944d57 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -15,6 +15,10 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
 
+        if self.offline:
+            msg = f'{self.selected_model} is not a valid model choice.'
+            print(msg)
+            return jsonify({'results': [{'text': format_sillytavern_err(msg)}]}), 200
 
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 1361fb4..fb02816 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -13,6 +13,7 @@ from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
+from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
@@ -42,6 +43,11 @@ class RequestHandler:
         if not self.cluster_backend_info.get('model'):
             print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
 
+        if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model'):
+            self.offline = True
+        else:
+            self.offline = False
+
         self.selected_model = self.cluster_backend_info['model']
         self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
         self.parameters = None
@@ -215,8 +221,11 @@ class RequestHandler:
 
     def handle_request(self) -> Tuple[flask.Response, int]:
         # Must include this in your child.
-        # if self.used:
-        #     raise Exception('Can only use a RequestHandler object once.')
+        # assert not self.used
+        # if self.offline:
+        #     msg = f'{self.selected_model} is not a valid model choice.'
+        #     print(msg)
+        #     return format_sillytavern_err(msg)
         raise NotImplementedError
 
     def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
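
For context, the template comment updated in the last hunk is meant to be copied into each concrete handler, not executed in the base class. A minimal sketch of what that looks like in a child class, assuming a hypothetical `ExampleRequestHandler` (the guard body mirrors the `OobaRequestHandler` change above; the class name and everything after the guard are illustrative):

```python
from typing import Tuple

import flask
from flask import jsonify

from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler


class ExampleRequestHandler(RequestHandler):
    # Hypothetical subclass following the base-class template comment.
    def handle_request(self) -> Tuple[flask.Response, int]:
        # Per the template: a handler object is single-use.
        assert not self.used

        # self.offline is set in RequestHandler.__init__() when the selected
        # backend reports no 'mode' or 'model'; bail out before doing any work.
        if self.offline:
            msg = f'{self.selected_model} is not a valid model choice.'
            print(msg)
            return jsonify({'results': [{'text': format_sillytavern_err(msg)}]}), 200

        # ... backend-specific handling would continue here ...
        raise NotImplementedError
```

Note that the guard returns HTTP 200 with the formatted error embedded in the response text rather than an error status; presumably this is so SillyTavern displays the message to the user instead of treating the reply as a transport failure.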