diff --git a/llm_server/config.py b/llm_server/config.py
index f379260..4e610a3 100644
--- a/llm_server/config.py
+++ b/llm_server/config.py
@@ -15,7 +15,8 @@ config_default_vars = {
     'show_total_output_tokens': True,
     'simultaneous_requests_per_ip': 3,
     'show_backend_info': True,
-    'max_new_tokens': 500
+    'max_new_tokens': 500,
+    'manual_model_name': False
 }
 
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 7302728..cf27e67 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -11,7 +11,7 @@ class LLMBackend:
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     raise NotImplementedError
 
-    def get_parameters(self, parameters) -> Union[dict, None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         """
         Validate and return the parameters for this backend.
         Lets you set defaults for specific backends.
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 86e8129..5a57de6 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -78,7 +78,7 @@ class VLLMBackend(LLMBackend):
     #     except Exception as e:
     #         return False, e
 
-    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         default_params = SamplingParams()
         try:
             sampling_params = SamplingParams(
@@ -91,8 +91,7 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
             )
         except ValueError as e:
-            print(e)
-            return None, e
+            return None, str(e).strip('.')
         return vars(sampling_params), None
 
     # def transform_sampling_params(params: SamplingParams):
diff --git a/llm_server/opts.py b/llm_server/opts.py
index f1a5b0b..8708b83 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -22,3 +22,4 @@ show_total_output_tokens = True
 netdata_root = None
 simultaneous_requests_per_ip = 3
 show_backend_info = True
+manual_model_name = None
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 7c9a962..da3c238 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -56,10 +56,6 @@ class OobaRequestHandler:
         else:
             return self.request.remote_addr
 
-    # def get_parameters(self):
-    #     # TODO: make this a LLMBackend method
-    #     return self.backend.get_parameters()
-
     def get_priority(self):
         if self.token:
             conn = sqlite3.connect(opts.database_path)
@@ -85,22 +81,22 @@ class OobaRequestHandler:
         self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
 
     def handle_request(self):
+        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
+
         request_valid_json, self.request_json_body = validate_json(self.request.data)
         if not request_valid_json:
             return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
 
         self.get_parameters()
-
-        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
-
-        request_valid, invalid_request_err_msg = self.validate_request()
-        if not self.parameters:
-            params_valid = False
-        else:
+        params_valid = False
+        request_valid = False
+        invalid_request_err_msg = None
+        if self.parameters:
             params_valid = True
+            request_valid, invalid_request_err_msg = self.validate_request()
 
         if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
             err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
diff --git a/llm_server/routes/server_error.py b/llm_server/routes/server_error.py
new file mode 100644
index 0000000..fec3836
--- /dev/null
+++ b/llm_server/routes/server_error.py
@@ -0,0 +1,3 @@
+def handle_server_error(e):
+    print(e)
+    return {'error': True}, 500
diff --git a/llm_server/routes/v1/__init__.py b/llm_server/routes/v1/__init__.py
index 25ab2f6..e25fa82 100644
--- a/llm_server/routes/v1/__init__.py
+++ b/llm_server/routes/v1/__init__.py
@@ -1,6 +1,7 @@
 from flask import Blueprint, request
 
 from ..helpers.http import require_api_key
+from ..server_error import handle_server_error
 from ... import opts
 
 bp = Blueprint('v1', __name__)
@@ -18,4 +19,9 @@ def before_request():
     return response
 
 
+@bp.errorhandler(500)
+def handle_error(e):
+    return handle_server_error(e)
+
+
 from . import generate, info, proxy, generate_stream
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 0f42ccf..4b78bf8 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -9,7 +9,7 @@ from ... import opts
 @bp.route('/generate', methods=['POST'])
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
-    if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
+    if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         handler = OobaRequestHandler(request)
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 096c465..0eea965 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -94,7 +94,7 @@ def generate_stats():
             'gatekeeper': 'none' if opts.auth_required is False else 'token',
             'context_size': opts.context_size,
             'concurrent': opts.concurrent_gens,
-            'model': model_name,
+            'model': opts.manual_model_name if opts.manual_model_name else model_name,
             'mode': opts.mode,
             'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
         },
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 26bb2e3..883878e 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -4,6 +4,7 @@ from flask import jsonify, request
 
 from . import bp
 from ..cache import cache
+from ... import opts
 from ...llm.info import get_running_model
 
 
@@ -27,8 +28,8 @@ def get_model():
     if cached_response:
         return cached_response
 
-    model, error = get_running_model()
-    if not model:
+    model_name, error = get_running_model()
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',
@@ -36,7 +37,7 @@ def get_model():
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         response = jsonify({
-            'result': model,
+            'result': opts.manual_model_name if opts.manual_model_name else model_name,
             'timestamp': int(time.time())
         }), 200
     cache.set(cache_key, response, timeout=60)
diff --git a/server.py b/server.py
index 9c2ac1f..e50e45a 100644
--- a/server.py
+++ b/server.py
@@ -6,6 +6,8 @@ from threading import Thread
 
 from flask import Flask, jsonify, render_template, request
 
+from llm_server.routes.server_error import handle_server_error
+
 try:
     import vllm
 except ModuleNotFoundError as e:
@@ -65,6 +67,7 @@ opts.netdata_root = config['netdata_root']
 opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
 opts.show_backend_info = config['show_backend_info']
 opts.max_new_tokens = config['max_new_tokens']
+opts.manual_model_name = config['manual_model_name']
 opts.verify_ssl = config['verify_ssl']
 
 if not opts.verify_ssl:
@@ -145,7 +148,7 @@ def home():
                            llm_middleware_name=config['llm_middleware_name'],
                            analytics_tracking_code=analytics_tracking_code,
                            info_html=info_html,
-                           current_model=running_model,
+                           current_model=opts.manual_model_name if opts.manual_model_name else running_model,
                            client_api=stats['endpoints']['blocking'],
                            ws_client_api=stats['endpoints']['streaming'],
                            estimated_wait=estimated_wait_sec,
@@ -169,8 +172,7 @@ def fallback(first=None, rest=None):
 
 @app.errorhandler(500)
 def server_error(e):
-    print(e)
-    return {'error': True}, 500
+    return handle_server_error(e)
 
 
 if __name__ == "__main__":