fix homepage slowness, fix incorrect 24 hr prompters, fix redis wrapper

Cyberes 2023-09-25 17:20:21 -06:00
parent 52e6965b5e
commit 135bd743bb
11 changed files with 53 additions and 54 deletions

View File

@@ -145,7 +145,7 @@ def get_distinct_ips_24h():
     conn = db_pool.connection()
     cursor = conn.cursor()
     try:
-        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
+        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
         return result[0] if result else 0
     finally:
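
Why the extra predicate matters: `NOT LIKE` follows SQL three-valued logic, so a NULL `token` makes the whole condition NULL and the row is silently dropped, undercounting prompters who sent no token. A minimal, self-contained sketch of the behaviour (using SQLite in place of the project's MySQL pool; the table and values are made up for illustration):

```python
# Demonstrates why `OR token IS NULL` is needed: `NOT LIKE` evaluates to NULL
# (i.e. not true) for NULL tokens, so anonymous prompters were not counted.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE prompts (ip TEXT, token TEXT)")
conn.executemany("INSERT INTO prompts VALUES (?, ?)",
                 [("1.1.1.1", "SYSTEM__backend"),  # excluded either way
                  ("2.2.2.2", "user-token"),       # counted either way
                  ("3.3.3.3", None)])              # only counted with the fix

old = conn.execute("SELECT COUNT(DISTINCT ip) FROM prompts "
                   "WHERE token NOT LIKE 'SYSTEM__%'").fetchone()[0]
new = conn.execute("SELECT COUNT(DISTINCT ip) FROM prompts "
                   "WHERE token NOT LIKE 'SYSTEM__%' OR token IS NULL").fetchone()[0]
print(old, new)  # 1 2
```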

View File

@@ -30,23 +30,10 @@ def safe_list_get(l, idx, default):
 def deep_sort(obj):
-    """
-    https://stackoverflow.com/a/59218649
-    :param obj:
-    :return:
-    """
     if isinstance(obj, dict):
-        obj = OrderedDict(sorted(obj.items()))
-        for k, v in obj.items():
-            if isinstance(v, dict) or isinstance(v, list):
-                obj[k] = deep_sort(v)
+        return OrderedDict((k, deep_sort(v)) for k, v in sorted(obj.items()))
     if isinstance(obj, list):
-        for i, v in enumerate(obj):
-            if isinstance(v, dict) or isinstance(v, list):
-                obj[i] = deep_sort(v)
-        obj = sorted(obj, key=lambda x: json.dumps(x))
+        return sorted(deep_sort(x) for x in obj)
     return obj
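
The rewritten helper builds and returns freshly sorted containers on the way back up the recursion instead of mutating in place, which is what gives the stats payload a deterministic key order. A small usage sketch of the function as it appears in this commit:

```python
from collections import OrderedDict

def deep_sort(obj):
    # Same shape as the rewritten helper: recurse first, then return a
    # freshly sorted container rather than mutating the input.
    if isinstance(obj, dict):
        return OrderedDict((k, deep_sort(v)) for k, v in sorted(obj.items()))
    if isinstance(obj, list):
        return sorted(deep_sort(x) for x in obj)
    return obj

print(deep_sort({"b": [3, 1, 2], "a": {"z": 1, "y": 2}}))
# -> keys come back sorted ('a' before 'b'), nested dict keys sorted, list sorted to [1, 2, 3]
```

Note that the old version sorted lists with `key=lambda x: json.dumps(x)`; the new one relies on list elements being directly comparable, so lists of dicts would now raise a TypeError when sorted.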

View File

@@ -29,4 +29,4 @@ openai_api_key = None
 backend_request_timeout = 30
 backend_generate_request_timeout = 95
 admin_token = None
-openai_epose_our_model = False
+openai_expose_our_model = False

View File

@@ -1,13 +1,14 @@
-import json
 import sys
 import traceback
+from typing import Union
 
 import redis as redis_pkg
+import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import FieldT
+from redis.typing import FieldT, ExpiryT
 
-cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
+cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
 ONE_MONTH_SECONDS = 2678000
@@ -30,26 +31,31 @@ class RedisWrapper:
     def _key(self, key):
         return f"{self.prefix}:{key}"
 
-    def set(self, key, value):
-        return self.redis.set(self._key(key), value)
+    def set(self, key, value, ex: Union[ExpiryT, None] = None):
+        return self.redis.set(self._key(key), value, ex=ex)
 
-    def get(self, key, dtype=None):
+    def get(self, key, dtype=None, default=None):
         """
        :param key:
        :param dtype: convert to this type
        :return:
        """
         d = self.redis.get(self._key(key))
         if dtype and d:
             try:
                 if dtype == str:
                     return d.decode('utf-8')
+                if dtype in [dict, list]:
+                    return json.loads(d.decode("utf-8"))
                 else:
                     return dtype(d)
             except:
                 traceback.print_exc()
-        return d
+        if not d:
+            return default
+        else:
+            return d
 
     def incr(self, key, amount=1):
         return self.redis.incr(self._key(key), amount)
@@ -66,11 +72,11 @@ class RedisWrapper:
     def sismember(self, key: str, value: str):
         return self.redis.sismember(self._key(key), value)
 
-    def set_dict(self, key, dict_value):
-        return self.set(self._key(key), json.dumps(dict_value))
+    def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
+        return self.set(key, json.dumps(dict_value), ex=ex)
 
     def get_dict(self, key):
-        r = self.get(self._key(key))
+        r = self.get(key)
         if not r:
             return dict()
         else:
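
Taken together, the wrapper changes add an optional expiry on `set`/`set_dict`, typed reads with a `default`, JSON decoding for `dict`/`list` dtypes, and stop `set_dict`/`get_dict` from applying the key prefix twice (the old code prefixed once in `set_dict` and again inside `set`, so reads could miss writes). A minimal usage sketch, assuming the shared instance exported as `redis` from `llm_server.routes.cache`; the key names are taken from later hunks:

```python
# Usage sketch of the reworked wrapper (shared instance from the module above).
from llm_server.routes.cache import redis

redis.set('estimated_avg_tps', 12.3, ex=60)              # optional expiry in seconds
tps = redis.get('estimated_avg_tps', float, default=0)   # typed read with fallback

redis.set_dict('proxy_stats', {'online': True})           # key is prefixed exactly once now
stats = redis.get_dict('proxy_stats')                      # -> {} when the key is missing
```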

View File

@@ -50,7 +50,7 @@ def openai_chat_completions():
     response = generator(msg_to_backend)
     r_headers = dict(request.headers)
     r_url = request.url
-    model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
+    model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
 
     def generate():
         generated_text = ''

View File

@@ -24,7 +24,7 @@ def openai_list_models():
     else:
         oai = fetch_openai_models()
         r = []
-        if opts.openai_epose_our_model:
+        if opts.openai_expose_our_model:
             r = [{
                 "object": "list",
                 "data": [

View File

@@ -150,7 +150,7 @@ def build_openai_response(prompt, response, model=None):
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model if opts.openai_epose_our_model else model,
+        "model": opts.running_model if opts.openai_expose_our_model else model,
         "choices": [{
             "index": 0,
             "message": {

View File

@@ -6,7 +6,7 @@ from llm_server.database.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import cache, redis
+from llm_server.routes.cache import redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
@@ -35,8 +35,12 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
 
-@cache.memoize(timeout=10)
-def generate_stats():
+def generate_stats(regen: bool = False):
+    if not regen:
+        c = redis.get('proxy_stats', dict)
+        if c:
+            return c
+
     model_name, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_name, bool):
         online = False
@@ -53,12 +57,10 @@ def generate_stats():
     active_gen_workers = get_active_gen_workers()
     proompters_in_queue = len(priority_queue)
 
-    estimated_avg_tps = float(redis.get('estimated_avg_tps'))
+    estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
 
     if opts.average_generation_time_mode == 'database':
-        average_generation_time = float(redis.get('average_generation_elapsed_sec'))
-        # average_output_tokens = float(redis.get('average_output_tokens'))
-        # average_generation_time_from_tps = (average_output_tokens / estimated_avg_tps)
+        average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
 
         # What to use in our math that calculates the wait time.
         # We could use the average TPS but we don't know the exact TPS value, only
@@ -85,13 +87,8 @@ def generate_stats():
     else:
         netdata_stats = {}
 
-    x = redis.get('base_client_api')
-    base_client_api = x.decode() if x else None
-    del x
-
-    x = redis.get('proompters_5_min')
-    proompters_5_min = int(x) if x else None
-    del x
+    base_client_api = redis.get('base_client_api', str)
+    proompters_5_min = redis.get('proompters_5_min', str)
 
     output = {
         'stats': {
@@ -131,4 +128,9 @@ def generate_stats():
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
-    return deep_sort(output)
+    result = deep_sort(output)
+
+    # It may take a bit to get the base client API, so don't cache until then.
+    if base_client_api:
+        redis.set_dict('proxy_stats', result)  # Cache with no expiry
+    return result

View File

@@ -67,5 +67,5 @@ class MainBackgroundThread(Thread):
 def cache_stats():
     while True:
-        x = generate_stats()
+        generate_stats(regen=True)
         time.sleep(5)
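
The loop forces `regen=True` so the Redis copy is rebuilt every five seconds, while web requests only ever read the cached value. The hunk shows just the loop body; the sketch below shows one way it could be started as a daemon thread. This wiring is an assumption, the actual startup code lives elsewhere in server.py, whose hunk headers below mention `stats_updater_thread`:

```python
# Assumption: cache_stats (the loop above) runs on a daemon thread so it keeps
# refreshing 'proxy_stats' for the lifetime of the server process.
from threading import Thread

stats_updater_thread = Thread(target=cache_stats, daemon=True)
stats_updater_thread.start()
```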

other/vllm/vllm_api_server.py Normal file → Executable file (mode change only, 0 changed lines)
View File

View File

@@ -1,4 +1,5 @@
 import os
+import re
 import sys
 from pathlib import Path
 from threading import Thread
@@ -93,13 +94,14 @@ opts.enable_streaming = config['enable_streaming']
 opts.openai_api_key = config['openai_api_key']
 openai.api_key = opts.openai_api_key
 opts.admin_token = config['admin_token']
-opts.openai_epose_our_model = config['openai_epose_our_model']
+opts.openai_expose_our_model = config['openai_epose_our_model']
+config["http_host"] = re.sub(r'http(?:s)?://', '', config["http_host"])
 
-if opts.openai_epose_our_model and not opts.openai_api_key:
+if opts.openai_expose_our_model and not opts.openai_api_key:
     print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
     sys.exit(1)
 
 if config['http_host']:
     redis.set('http_host', config['http_host'])
     redis.set('base_client_api', f'{config["http_host"]}/{opts.frontend_api_client.strip("/")}')
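
The new `re.sub` strips a leading `http://` or `https://` from `http_host`, so the `base_client_api` value written to Redis (and later rendered on the homepage) is always a bare host plus the frontend API path. A minimal sketch, with a hypothetical host and assuming a `frontend_api_client` of `/api/v1`:

```python
# Sketch of the http_host normalisation; the host values and '/api/v1' are
# assumptions for illustration, only the regex comes from the diff.
import re

for host in ('https://llm.example.com', 'llm.example.com'):
    cleaned = re.sub(r'http(?:s)?://', '', host)
    print(f"{cleaned}/api/v1")  # llm.example.com/api/v1 both times
```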
@@ -141,6 +143,10 @@ process_avg_gen_time_background_thread.start()
 MainBackgroundThread().start()
 SemaphoreCheckerThread().start()
 
+# Cache the initial stats
+print('Loading backend stats...')
+generate_stats()
+
 init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
@@ -161,7 +167,7 @@ stats_updater_thread.start()
 def home():
     stats = generate_stats()
-    if not bool(redis.get('backend_online')) or not stats['online']:
+    if not stats['online']:
         running_model = estimated_wait_sec = 'offline'
     else:
         running_model = opts.running_model
@@ -188,9 +194,7 @@ def home():
     if opts.mode == 'vllm':
         mode_info = vllm_info
 
-    x = redis.get('base_client_api')
-    base_client_api = x.decode() if x else None
-    del x
+    base_client_api = redis.get('base_client_api', str)
 
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,