don't store host if it's an IP

Cyberes 2023-09-23 23:14:22 -06:00
parent 0015e653b2
commit 84a1fcfdd8
8 changed files with 43 additions and 19 deletions

View File

@@ -41,7 +41,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
        conn.commit()
    finally:
        cursor.close()
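For context, a minimal standalone sketch of the parameterized-insert pattern used above; the table, columns, and driver are stand-ins (SQLite with ? placeholders instead of the project's MySQL-style %s), not the project's real schema.

import json
import sqlite3  # stand-in driver for illustration only

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE prompts (ip TEXT, token TEXT, parameters TEXT)')
cursor = conn.cursor()
try:
    # Placeholders keep user-supplied values out of the SQL string itself;
    # dict-like fields are serialized with json.dumps, as in log_prompt().
    cursor.execute('INSERT INTO prompts VALUES (?, ?, ?)',
                   ('1.2.3.4', 'abc123', json.dumps({'temperature': 0.7})))
    conn.commit()
finally:
    cursor.close()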

View File

@@ -1,10 +1,14 @@
import simplejson as json
import math
import re
from collections import OrderedDict
from pathlib import Path
import simplejson as json
from flask import make_response
from llm_server import opts
from llm_server.routes.cache import redis
def resolve_path(*p: str):
    return Path(*p).expanduser().resolve().absolute()
@@ -62,3 +66,14 @@ def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
def round_up_base(n, base):
    return math.ceil(n / base) * base
def set_base_client_api(request):
    http_host = redis.get('http_host', str)
    host = request.headers.get("Host")
    if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
        # If the current http_host is not an IP, don't do anything.
        return
    else:
        redis.set('http_host', host)
        redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}')
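As a rough standalone illustration of the check above (the Host values below are made up): re.match anchors at the start of the string, so only dotted numeric hosts, with or without a port suffix, are treated as IPs and overwritten.

import re

IPV4_RE = r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}'  # same pattern as above

for candidate in ('192.168.1.5', '10.0.0.1:5000', 'llm.example.com'):
    print(candidate, '-> treated as an IP:', bool(re.match(IPV4_RE, candidate)))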

View File

@@ -11,7 +11,6 @@ max_new_tokens = 500
auth_required = False
log_prompts = False
frontend_api_client = ''
http_host = None
verify_ssl = True
show_num_prompts = True
show_uptime = True

View File

@@ -1,5 +1,6 @@
import json
import sys
import traceback
import redis as redis_pkg
from flask_caching import Cache
@@ -32,8 +33,23 @@ class RedisWrapper:
    def set(self, key, value):
        return self.redis.set(self._key(key), value)
    def get(self, key):
        return self.redis.get(self._key(key))
    def get(self, key, dtype=None):
        """
        :param key:
        :param dtype: convert to this type
        :return:
        """
        d = self.redis.get(self._key(key))
        if dtype and d:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                else:
                    return dtype(d)
            except:
                traceback.print_exc()
        return d
    def incr(self, key, amount=1):
        return self.redis.incr(self._key(key), amount)
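A standalone sketch of the conversion the reworked get() performs (the helper name and sample values are illustrative): redis-py returns bytes, so a str request is decoded as UTF-8, any other dtype is applied as a constructor, and the raw value comes back unchanged if conversion fails.

def _convert(raw, dtype=None):
    # Mirrors the dtype handling above: decode str, otherwise call the
    # constructor; on a conversion error, fall back to the raw bytes.
    if dtype and raw:
        try:
            return raw.decode('utf-8') if dtype == str else dtype(raw)
        except Exception:
            pass
    return raw

print(_convert(b'llm.example.com', str))  # 'llm.example.com'
print(_convert(b'42', int))               # 42
print(_convert(None, int))                # None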

View File

@@ -1,9 +1,11 @@
from llm_server import opts
from llm_server.routes.cache import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
    http_host = redis.get('http_host')
    return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {opts.http_host} ===
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
{msg}
```"""
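A quick sketch of the resulting message once the banner host is read back from Redis instead of opts.http_host; the host and message values below are placeholders.

http_host = 'llm.example.com'   # stands in for redis.get('http_host')
msg, level = 'Backend is offline.', 'error'
print(f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
{msg}
```""")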

View File

@@ -6,6 +6,7 @@ from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response
from ..server_error import handle_server_error
from ... import opts
from ...helpers import set_base_client_api
openai_bp = Blueprint('openai/v1/', __name__)
@@ -13,12 +14,9 @@ openai_bp = Blueprint('openai/v1/', __name__)
@openai_bp.before_request
def before_oai_request():
    # TODO: unify with normal before_request()
    if not opts.http_host:
        opts.http_host = request.headers.get("Host")
    set_base_client_api(request)
    if not opts.enable_openi_compatible_backend:
        return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
    if not redis.get('base_client_api'):
        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:

View File

@@ -7,6 +7,7 @@ from flask import Response, request
from llm_server import opts
from llm_server.database.conn import db_pool
from llm_server.database.database import log_prompt
from llm_server.helpers import set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
@@ -193,10 +194,7 @@ def delete_dict_key(d: dict, k: Union[str, list]):
def before_request():
    if not opts.http_host:
        opts.http_host = request.headers.get("Host")
    if not redis.get('base_client_api'):
        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
    set_base_client_api(request)
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:

View File

@@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path
from llm_server.helpers import resolve_path, set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
@@ -204,10 +204,7 @@ def server_error(e):
@app.before_request
def before_app_request():
    if not opts.http_host:
        opts.http_host = request.headers.get("Host")
    if not redis.get('base_client_api'):
        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
    set_base_client_api(request)
if __name__ == "__main__":
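Taken together, the three before-request hooks now reduce to the same call; a minimal sketch of the consolidated pattern, assuming the llm_server package is importable as in the repo.

from flask import Flask, request
from llm_server.helpers import set_base_client_api

app = Flask(__name__)

@app.before_request
def before_app_request():
    # Host bookkeeping now lives in one place instead of being repeated
    # in each blueprint's before_request hook.
    set_base_client_api(request)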