don't store host if it's an IP

This commit is contained in:
Cyberes 2023-09-23 23:14:22 -06:00
parent 0015e653b2
commit 84a1fcfdd8
8 changed files with 43 additions and 19 deletions

View File

@ -41,7 +41,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", """,
(ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit()
finally: finally:
cursor.close() cursor.close()

View File

@ -1,10 +1,14 @@
import simplejson as json
import math import math
import re
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
import simplejson as json
from flask import make_response from flask import make_response
from llm_server import opts
from llm_server.routes.cache import redis
def resolve_path(*p: str): def resolve_path(*p: str):
return Path(*p).expanduser().resolve().absolute() return Path(*p).expanduser().resolve().absolute()
@ -62,3 +66,14 @@ def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
def round_up_base(n, base): def round_up_base(n, base):
return math.ceil(n / base) * base return math.ceil(n / base) * base
def set_base_client_api(request):
http_host = redis.get('http_host', str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.
return
else:
redis.set('http_host', host)
redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}')

View File

@ -11,7 +11,6 @@ max_new_tokens = 500
auth_required = False auth_required = False
log_prompts = False log_prompts = False
frontend_api_client = '' frontend_api_client = ''
http_host = None
verify_ssl = True verify_ssl = True
show_num_prompts = True show_num_prompts = True
show_uptime = True show_uptime = True

View File

@ -1,5 +1,6 @@
import json import json
import sys import sys
import traceback
import redis as redis_pkg import redis as redis_pkg
from flask_caching import Cache from flask_caching import Cache
@ -32,8 +33,23 @@ class RedisWrapper:
def set(self, key, value): def set(self, key, value):
return self.redis.set(self._key(key), value) return self.redis.set(self._key(key), value)
def get(self, key): def get(self, key, dtype=None):
return self.redis.get(self._key(key)) """
:param key:
:param dtype: convert to this type
:return:
"""
d = self.redis.get(self._key(key))
if dtype and d:
try:
if dtype == str:
return d.decode('utf-8')
else:
return dtype(d)
except:
traceback.print_exc()
return d
def incr(self, key, amount=1): def incr(self, key, amount=1):
return self.redis.incr(self._key(key), amount) return self.redis.incr(self._key(key), amount)

View File

@ -1,9 +1,11 @@
from llm_server import opts from llm_server import opts
from llm_server.routes.cache import redis
def format_sillytavern_err(msg: str, level: str = 'info'): def format_sillytavern_err(msg: str, level: str = 'info'):
http_host = redis.get('http_host')
return f"""``` return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {opts.http_host} === === MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <- -> {level.upper()} <-
{msg} {msg}
```""" ```"""

View File

@ -6,6 +6,7 @@ from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response from ..openai_request_handler import build_openai_response
from ..server_error import handle_server_error from ..server_error import handle_server_error
from ... import opts from ... import opts
from ...helpers import set_base_client_api
openai_bp = Blueprint('openai/v1/', __name__) openai_bp = Blueprint('openai/v1/', __name__)
@ -13,12 +14,9 @@ openai_bp = Blueprint('openai/v1/', __name__)
@openai_bp.before_request @openai_bp.before_request
def before_oai_request(): def before_oai_request():
# TODO: unify with normal before_request() # TODO: unify with normal before_request()
if not opts.http_host: set_base_client_api(request)
opts.http_host = request.headers.get("Host")
if not opts.enable_openi_compatible_backend: if not opts.enable_openi_compatible_backend:
return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401 return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
if not redis.get('base_client_api'):
redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
if request.endpoint != 'v1.get_stats': if request.endpoint != 'v1.get_stats':
response = require_api_key() response = require_api_key()
if response is not None: if response is not None:

View File

@ -7,6 +7,7 @@ from flask import Response, request
from llm_server import opts from llm_server import opts
from llm_server.database.conn import db_pool from llm_server.database.conn import db_pool
from llm_server.database.database import log_prompt from llm_server.database.database import log_prompt
from llm_server.helpers import set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
@ -193,10 +194,7 @@ def delete_dict_key(d: dict, k: Union[str, list]):
def before_request(): def before_request():
if not opts.http_host: set_base_client_api(request)
opts.http_host = request.headers.get("Host")
if not redis.get('base_client_api'):
redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
if request.endpoint != 'v1.get_stats': if request.endpoint != 'v1.get_stats':
response = require_api_key() response = require_api_key()
if response is not None: if response is not None:

View File

@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
import config import config
from llm_server import opts from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path from llm_server.helpers import resolve_path, set_base_client_api
from llm_server.llm.vllm.info import vllm_info from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers from llm_server.routes.queue import start_workers
@ -204,10 +204,7 @@ def server_error(e):
@app.before_request @app.before_request
def before_app_request(): def before_app_request():
if not opts.http_host: set_base_client_api(request)
opts.http_host = request.headers.get("Host")
if not redis.get('base_client_api'):
redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
if __name__ == "__main__": if __name__ == "__main__":