From 84a1fcfdd8204befa5a5ce6727759b0e2150f4f2 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sat, 23 Sep 2023 23:14:22 -0600
Subject: [PATCH] don't store host if it's an IP

---
 llm_server/database/database.py      |  1 -
 llm_server/helpers.py                | 17 ++++++++++++++++-
 llm_server/opts.py                   |  1 -
 llm_server/routes/cache.py           | 20 ++++++++++++++++++--
 llm_server/routes/helpers/client.py  |  4 +++-
 llm_server/routes/openai/__init__.py |  6 ++----
 llm_server/routes/request_handler.py |  6 ++----
 server.py                            |  7 ++-----
 8 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index aa7a685..131e4dc 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -41,7 +41,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
             """,
                        (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-        conn.commit()
     finally:
         cursor.close()
 
diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index 693714e..e0a4f1e 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -1,10 +1,14 @@
-import simplejson as json
 import math
+import re
 from collections import OrderedDict
 from pathlib import Path
 
+import simplejson as json
 from flask import make_response
 
+from llm_server import opts
+from llm_server.routes.cache import redis
+
 
 def resolve_path(*p: str):
     return Path(*p).expanduser().resolve().absolute()
@@ -62,3 +66,14 @@ def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
 
 def round_up_base(n, base):
     return math.ceil(n / base) * base
+
+
+def set_base_client_api(request):
+    http_host = redis.get('http_host', str)
+    host = request.headers.get("Host")
+    if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
+        # If the current http_host is not an IP, don't do anything.
+        return
+    else:
+        redis.set('http_host', host)
+        redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}')
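Note: the dotted-quad regex in set_base_client_api() only understands IPv4, and since re.match() anchors only at the start it also treats a host that merely begins with four octets (e.g. '1.2.3.4.example.com') as an IP. A minimal sketch of the same check built on the stdlib ipaddress module instead; the helper name and the example hosts are hypothetical, not part of this patch:

    import ipaddress


    def host_is_ip(http_host: str) -> bool:
        # Drop an optional :port suffix ("192.168.1.5:7000" -> "192.168.1.5").
        # Like the regex above, this only considers IPv4; bracketed IPv6
        # hosts ("[::1]:5000") would need extra handling.
        candidate = (http_host or '').rsplit(':', 1)[0]
        try:
            ipaddress.ip_address(candidate)  # raises ValueError for hostnames
            return True
        except ValueError:
            return False


    host_is_ip('192.168.1.5:7000')  # True
    host_is_ip('llm.example.com')   # False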
diff --git a/llm_server/opts.py b/llm_server/opts.py
index 72f441c..8f997a9 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -11,7 +11,6 @@ max_new_tokens = 500
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
-http_host = None
 verify_ssl = True
 show_num_prompts = True
 show_uptime = True
diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py
index 3af1005..daaf2da 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/routes/cache.py
@@ -1,5 +1,6 @@
 import json
 import sys
+import traceback
 
 import redis as redis_pkg
 from flask_caching import Cache
@@ -32,8 +33,23 @@
     def set(self, key, value):
         return self.redis.set(self._key(key), value)
 
-    def get(self, key):
-        return self.redis.get(self._key(key))
+    def get(self, key, dtype=None):
+        """Fetch a key from redis, optionally coercing the raw bytes.
+
+        :param key: key name (namespaced by self._key()).
+        :param dtype: if given, convert the stored value to this type.
+        :return: the converted value, or the raw bytes if dtype is unset or conversion fails.
+        """
+        d = self.redis.get(self._key(key))
+        if dtype and d:
+            try:
+                if dtype == str:
+                    return d.decode('utf-8')
+                else:
+                    return dtype(d)
+            except Exception:
+                traceback.print_exc()
+        return d
 
     def incr(self, key, amount=1):
         return self.redis.incr(self._key(key), amount)
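Note: a short usage sketch for the new dtype parameter on RedisWrapper.get(); the key names and values are illustrative only. redis-py hands back raw bytes, which is what the conversion is for:

    from llm_server.routes.cache import redis

    redis.set('http_host', 'llm.example.com')  # hypothetical stored value

    redis.get('http_host')       # b'llm.example.com' (raw bytes, the old behavior)
    redis.get('http_host', str)  # 'llm.example.com'
    redis.get('http_host', int)  # conversion fails: prints a traceback, returns the raw bytes
    redis.get('missing', str)    # None: falsy values are returned unconverted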
diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py
index 18a5706..0b135a5 100644
--- a/llm_server/routes/helpers/client.py
+++ b/llm_server/routes/helpers/client.py
@@ -1,9 +1,11 @@
 from llm_server import opts
+from llm_server.routes.cache import redis
 
 
 def format_sillytavern_err(msg: str, level: str = 'info'):
+    http_host = redis.get('http_host', str)
     return f"""```
-=== MESSAGE FROM LLM MIDDLEWARE AT {opts.http_host} ===
+=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
 -> {level.upper()} <-
 {msg}
 ```"""
diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py
index 1fe0349..4e1515f 100644
--- a/llm_server/routes/openai/__init__.py
+++ b/llm_server/routes/openai/__init__.py
@@ -6,6 +6,7 @@
 from ..helpers.http import require_api_key
 from ..openai_request_handler import build_openai_response
 from ..server_error import handle_server_error
 from ... import opts
+from ...helpers import set_base_client_api
 
 openai_bp = Blueprint('openai/v1/', __name__)
@@ -13,12 +14,9 @@ openai_bp = Blueprint('openai/v1/', __name__)
 @openai_bp.before_request
 def before_oai_request():
     # TODO: unify with normal before_request()
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
+    set_base_client_api(request)
     if not opts.enable_openi_compatible_backend:
         return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
-    if not redis.get('base_client_api'):
-        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index ef3c203..da9b8ab 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -7,6 +7,7 @@ from flask import Response, request
 from llm_server import opts
 from llm_server.database.conn import db_pool
 from llm_server.database.database import log_prompt
+from llm_server.helpers import set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
@@ -193,10 +194,7 @@ def delete_dict_key(d: dict, k: Union[str, list]):
 
 
 def before_request():
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
-    if not redis.get('base_client_api'):
-        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+    set_base_client_api(request)
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
diff --git a/server.py b/server.py
index 5653be0..a7f401c 100644
--- a/server.py
+++ b/server.py
@@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
 import config
 from llm_server import opts
 from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
-from llm_server.helpers import resolve_path
+from llm_server.helpers import resolve_path, set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
@@ -204,10 +204,7 @@ def server_error(e):
 
 @app.before_request
 def before_app_request():
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
-    if not redis.get('base_client_api'):
-        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+    set_base_client_api(request)
 
 
 if __name__ == "__main__":
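Note: with this change every before-request hook funnels through set_base_client_api(), and the middleware hostname shown to clients now comes out of redis rather than the removed opts.http_host. A rough sketch of the message a SillyTavern client would see, assuming 'llm.example.com' had been stored as http_host:

    from llm_server.routes.helpers.client import format_sillytavern_err

    print(format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied'))
    # ```
    # === MESSAGE FROM LLM MIDDLEWARE AT llm.example.com ===
    # -> ACCESS DENIED <-
    # The OpenAI-compatible backend is disabled.
    # ```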