diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index 5ed7a26..ed0847c 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -3,3 +3,17 @@ from pathlib import Path
 
 def resolve_path(*p: str):
     return Path(*p).expanduser().resolve().absolute()
+
+
+def safe_list_get(l, idx, default):
+    """
+    https://stackoverflow.com/a/5125636
+    :param l:
+    :param idx:
+    :param default:
+    :return:
+    """
+    try:
+        return l[idx]
+    except IndexError:
+        return default
diff --git a/llm_server/opts.py b/llm_server/opts.py
index 2ae99fd..f5be76a 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -9,3 +9,4 @@ database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
+http_host = None
diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py
new file mode 100644
index 0000000..18a5706
--- /dev/null
+++ b/llm_server/routes/helpers/client.py
@@ -0,0 +1,9 @@
+from llm_server import opts
+
+
+def format_sillytavern_err(msg: str, level: str = 'info'):
+    return f"""```
+=== MESSAGE FROM LLM MIDDLEWARE AT {opts.http_host} ===
+-> {level.upper()} <-
+{msg}
+```"""
diff --git a/llm_server/routes/v1/__init__.py b/llm_server/routes/v1/__init__.py
index f696082..0351270 100644
--- a/llm_server/routes/v1/__init__.py
+++ b/llm_server/routes/v1/__init__.py
@@ -1,6 +1,7 @@
 from flask import Blueprint, request
 
 from ..helpers.http import require_api_key
+from ... import opts
 
 bp = Blueprint('v1', __name__)
 
@@ -9,6 +10,8 @@
 
 @bp.before_request
 def before_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 36fac22..58f56a8 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -1,10 +1,12 @@
 from flask import jsonify, request
 
-from . import bp
 from llm_server.routes.stats import concurrent_semaphore, proompts
+from . import bp
+from ..helpers.client import format_sillytavern_err
 from ..helpers.http import cache_control, validate_json
 from ... import opts
 from ...database import log_prompt
+from ...helpers import safe_list_get
 
 
 def generator(request_json_body):
@@ -26,40 +28,72 @@ def generate():
         return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
 
     with concurrent_semaphore:
+        if request.headers.get('cf-connecting-ip'):
+            client_ip = request.headers.get('cf-connecting-ip')
+        elif request.headers.get('x-forwarded-for'):
+            client_ip = request.headers.get('x-forwarded-for')
+        else:
+            client_ip = request.remote_addr
+        token = request.headers.get('X-Api-Key')
+
+        parameters = request_json_body.copy()
+        del parameters['prompt']
+
         success, response, error_msg = generator(request_json_body)
         if not success:
+            if opts.mode == 'oobabooga':
+                backend_response = format_sillytavern_err(f'Failed to reach the backend: {error_msg}', 'error')
+                response_json_body = {
+                    'results': [
+                        {
+                            'text': backend_response,
+                        }
+                    ],
+                }
+            else:
+                raise Exception
+            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
             return jsonify({
                 'code': 500,
-                'error': 'failed to reach backend'
-            }), 500
+                'error': 'failed to reach backend',
+                **response_json_body
+            }), 200
         response_valid_json, response_json_body = validate_json(response)
         if response_valid_json:
             proompts.increment()
+            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
+            if not backend_response:
+                if opts.mode == 'oobabooga':
+                    backend_response = format_sillytavern_err(f'Backend returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.token_limit}.', 'error')
+                    response_json_body['results'][0]['text'] = backend_response
+                else:
+                    raise Exception
 
-            # request.headers = {"host": "proxy.chub-archive.evulid.cc", "x-forwarded-proto": "https", "user-agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)", "cf-visitor": {"scheme": "https"}, "cf-ipcountry": "CH", "accept": "*/*", "accept-encoding": "gzip",
-            # "x-forwarded-for": "193.32.127.228", "cf-ray": "7fa72c6a6d5cbba7-FRA", "cf-connecting-ip": "193.32.127.228", "cdn-loop": "cloudflare", "content-type": "application/json", "content-length": "9039"}
+            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
 
-            if request.headers.get('cf-connecting-ip'):
-                client_ip = request.headers.get('cf-connecting-ip')
-            elif request.headers.get('x-forwarded-for'):
-                client_ip = request.headers.get('x-forwarded-for')
-            else:
-                client_ip = request.remote_addr
+            print(response_json_body)
 
-            parameters = request_json_body.copy()
-            del parameters['prompt']
-
-            token = request.headers.get('X-Api-Key')
-
-            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
             return jsonify({
                 **response_json_body
            }), 200
         else:
+            if opts.mode == 'oobabooga':
+                backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
+                response_json_body = {
+                    'results': [
+                        {
+                            'text': backend_response,
+                        }
+                    ],
+                }
+            else:
+                raise Exception
+            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
             return jsonify({
                 'code': 500,
-                'error': 'failed to reach backend'
-            }), 500
+                'error': 'the backend did not return valid JSON',
+                **response_json_body
+            }), 200
 
 
 # @openai_bp.route('/chat/completions', methods=['POST'])
 # def generate_openai():
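
A note on the failure paths above: they now answer HTTP 200 with the formatted error embedded as results[0].text, so SillyTavern renders the message in-chat rather than treating a bare 500 as a hard failure. Below is a minimal sketch of the two new helpers working together; opts.http_host is stubbed here (normally it is captured from the Host header in v1.before_request) and the values are illustrative:

from llm_server import opts
from llm_server.helpers import safe_list_get
from llm_server.routes.helpers.client import format_sillytavern_err

opts.http_host = 'proxy.example.com'  # illustrative stand-in for the real Host header

# safe_list_get() returns the default instead of raising IndexError,
# which is how generate.py tolerates an empty 'results' list:
assert safe_list_get([], 0, {}) == {}
assert safe_list_get([{'text': 'hi'}], 0, {}).get('text') == 'hi'

# format_sillytavern_err() wraps the message in a fenced block so it
# shows up verbatim in the SillyTavern chat window:
print(format_sillytavern_err('Failed to reach the backend: timeout', 'error'))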
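
The client-IP detection, now hoisted to the top of the with concurrent_semaphore: block, prefers Cloudflare's cf-connecting-ip header, falls back to x-forwarded-for, and finally to the socket address. The same precedence restated as a standalone function for reference; resolve_client_ip is a hypothetical name, not part of the patch:

def resolve_client_ip(headers, remote_addr):
    # First truthy value wins, mirroring the if/elif/else chain in generate():
    return (headers.get('cf-connecting-ip')
            or headers.get('x-forwarded-for')
            or remote_addr)

assert resolve_client_ip({'cf-connecting-ip': '1.2.3.4', 'x-forwarded-for': '5.6.7.8'}, '127.0.0.1') == '1.2.3.4'
assert resolve_client_ip({}, '127.0.0.1') == '127.0.0.1'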