diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index c3716a7..991d15b 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -145,7 +145,7 @@ def get_distinct_ips_24h():
     conn = db_pool.connection()
     cursor = conn.cursor()
     try:
-        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
+        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
         return result[0] if result else 0
     finally:
diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index d70546b..1ba5db8 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -30,23 +30,10 @@ def safe_list_get(l, idx, default):
 
 
 def deep_sort(obj):
-    """
-    https://stackoverflow.com/a/59218649
-    :param obj:
-    :return:
-    """
     if isinstance(obj, dict):
-        obj = OrderedDict(sorted(obj.items()))
-        for k, v in obj.items():
-            if isinstance(v, dict) or isinstance(v, list):
-                obj[k] = deep_sort(v)
-
+        return OrderedDict((k, deep_sort(v)) for k, v in sorted(obj.items()))
     if isinstance(obj, list):
-        for i, v in enumerate(obj):
-            if isinstance(v, dict) or isinstance(v, list):
-                obj[i] = deep_sort(v)
-        obj = sorted(obj, key=lambda x: json.dumps(x))
-
+        return sorted(deep_sort(x) for x in obj)
     return obj
 
 
diff --git a/llm_server/opts.py b/llm_server/opts.py
index 42fdcab..be84e5b 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -29,4 +29,4 @@ openai_api_key = None
 backend_request_timeout = 30
 backend_generate_request_timeout = 95
 admin_token = None
-openai_epose_our_model = False
+openai_expose_our_model = False
diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py
index ae4e3cf..bd19254 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/routes/cache.py
@@ -1,13 +1,14 @@
-import json
 import sys
 import traceback
+from typing import Union
 
 import redis as redis_pkg
+import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import FieldT
+from redis.typing import FieldT, ExpiryT
 
-cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
+cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
 ONE_MONTH_SECONDS = 2678000
 
@@ -30,26 +31,31 @@ class RedisWrapper:
     def _key(self, key):
        return f"{self.prefix}:{key}"
 
-    def set(self, key, value):
-        return self.redis.set(self._key(key), value)
+    def set(self, key, value, ex: Union[ExpiryT, None] = None):
+        return self.redis.set(self._key(key), value, ex=ex)
 
-    def get(self, key, dtype=None):
+    def get(self, key, dtype=None, default=None):
         """
-        :param key:
         :param dtype: convert to this type
         :return:
         """
+
         d = self.redis.get(self._key(key))
         if dtype and d:
             try:
                 if dtype == str:
                     return d.decode('utf-8')
+                if dtype in [dict, list]:
+                    return json.loads(d.decode("utf-8"))
                 else:
                     return dtype(d)
             except:
                 traceback.print_exc()
-        return d
+        if not d:
+            return default
+        else:
+            return d
 
     def incr(self, key, amount=1):
         return self.redis.incr(self._key(key), amount)
@@ -66,11 +72,11 @@ class RedisWrapper:
     def sismember(self, key: str, value: str):
         return self.redis.sismember(self._key(key), value)
 
-    def set_dict(self, key, dict_value):
-        return self.set(self._key(key), json.dumps(dict_value))
+    def set_dict(self, key, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
+        return self.set(key, json.dumps(dict_value), ex=ex)
 
     def get_dict(self, key):
-        r = self.get(self._key(key))
+        r = self.get(key)
         if not r:
             return dict()
         else:
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 4069475..567d380 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -50,7 +50,7 @@ def openai_chat_completions():
         response = generator(msg_to_backend)
         r_headers = dict(request.headers)
         r_url = request.url
-        model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
+        model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
 
         def generate():
             generated_text = ''
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 6735560..f9d8591 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -24,7 +24,7 @@ def openai_list_models():
     else:
         oai = fetch_openai_models()
         r = []
-        if opts.openai_epose_our_model:
+        if opts.openai_expose_our_model:
             r = [{
                 "object": "list",
                 "data": [
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 354c261..9377e22 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -150,7 +150,7 @@ def build_openai_response(prompt, response, model=None):
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model if opts.openai_epose_our_model else model,
+        "model": opts.running_model if opts.openai_expose_our_model else model,
         "choices": [{
             "index": 0,
             "message": {
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 75cd373..8f4166f 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -6,7 +6,7 @@
 from llm_server.database.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import cache, redis
+from llm_server.routes.cache import redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
@@ -35,8 +35,12 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
 
 
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
-@cache.memoize(timeout=10)
-def generate_stats():
+def generate_stats(regen: bool = False):
+    if not regen:
+        c = redis.get('proxy_stats', dict)
+        if c:
+            return c
+
     model_name, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_name, bool):
         online = False
@@ -53,12 +57,10 @@ def generate_stats():
     active_gen_workers = get_active_gen_workers()
     proompters_in_queue = len(priority_queue)
 
-    estimated_avg_tps = float(redis.get('estimated_avg_tps'))
+    estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
 
     if opts.average_generation_time_mode == 'database':
-        average_generation_time = float(redis.get('average_generation_elapsed_sec'))
-        # average_output_tokens = float(redis.get('average_output_tokens'))
-        # average_generation_time_from_tps = (average_output_tokens / estimated_avg_tps)
+        average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
 
         # What to use in our math that calculates the wait time.
         # We could use the average TPS but we don't know the exact TPS value, only
@@ -85,13 +87,8 @@ def generate_stats():
     else:
         netdata_stats = {}
 
-    x = redis.get('base_client_api')
-    base_client_api = x.decode() if x else None
-    del x
-
-    x = redis.get('proompters_5_min')
-    proompters_5_min = int(x) if x else None
-    del x
+    base_client_api = redis.get('base_client_api', str)
+    proompters_5_min = redis.get('proompters_5_min', str)
 
     output = {
         'stats': {
@@ -131,4 +128,9 @@ def generate_stats():
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
-    return deep_sort(output)
+    result = deep_sort(output)
+
+    # It may take a bit to get the base client API, so don't cache until then.
+    if base_client_api:
+        redis.set_dict('proxy_stats', result)  # Cache with no expiry
+    return result
diff --git a/llm_server/threads.py b/llm_server/threads.py
index df5245d..8f796dc 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -67,5 +67,5 @@ class MainBackgroundThread(Thread):
 
 def cache_stats():
     while True:
-        x = generate_stats()
+        generate_stats(regen=True)
         time.sleep(5)
diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py
old mode 100644
new mode 100755
diff --git a/server.py b/server.py
index ace03a5..27885c1 100644
--- a/server.py
+++ b/server.py
@@ -1,4 +1,5 @@
 import os
+import re
 import sys
 from pathlib import Path
 from threading import Thread
@@ -93,13 +94,14 @@ opts.enable_streaming = config['enable_streaming']
 opts.openai_api_key = config['openai_api_key']
 openai.api_key = opts.openai_api_key
 opts.admin_token = config['admin_token']
-opts.openai_epose_our_model = config['openai_epose_our_model']
+opts.openai_expose_our_model = config['openai_epose_our_model']
 
-if opts.openai_epose_our_model and not opts.openai_api_key:
+config["http_host"] = re.sub(r'http(?:s)?://', '', config["http_host"])
+
+if opts.openai_expose_our_model and not opts.openai_api_key:
     print('If you set openai_epose_our_model to true, you must set your OpenAI key in openai_api_key.')
     sys.exit(1)
 
-
 if config['http_host']:
     redis.set('http_host', config['http_host'])
     redis.set('base_client_api', f'{config["http_host"]}/{opts.frontend_api_client.strip("/")}')
@@ -141,6 +143,10 @@ process_avg_gen_time_background_thread.start()
 MainBackgroundThread().start()
 SemaphoreCheckerThread().start()
 
+# Cache the initial stats
+print('Loading backend stats...')
+generate_stats()
+
 init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
@@ -161,7 +167,7 @@ stats_updater_thread.start()
 def home():
     stats = generate_stats()
 
-    if not bool(redis.get('backend_online')) or not stats['online']:
+    if not stats['online']:
         running_model = estimated_wait_sec = 'offline'
     else:
         running_model = opts.running_model
@@ -188,9 +194,7 @@ def home():
     if opts.mode == 'vllm':
         mode_info = vllm_info
 
-    x = redis.get('base_client_api')
-    base_client_api = x.decode() if x else None
-    del x
+    base_client_api = redis.get('base_client_api', str)
 
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
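
Note (not part of the diff): a minimal sketch of how the reworked RedisWrapper helpers in llm_server/routes/cache.py are meant to behave after this change, using the module-level `redis` wrapper that generate_stats.py already imports. It assumes a Redis server is reachable at redis://localhost:6379/0, and the keys shown are illustrative only.

    from llm_server.routes.cache import redis

    # Plain values: get() now converts to the requested dtype and falls back to
    # `default` when the key is missing, instead of returning raw bytes or None.
    redis.set('estimated_avg_tps', 4.2)
    tps = redis.get('estimated_avg_tps', float, default=0)   # -> 4.2
    missing = redis.get('no_such_key', float, default=0)     # -> 0, not None

    # Dict/list values: set_dict() JSON-encodes (via simplejson), get(..., dict)
    # decodes, and the optional ex= argument sets an expiry in seconds.
    redis.set_dict('proxy_stats', {'online': True}, ex=60)
    stats = redis.get('proxy_stats', dict)                   # -> {'online': True}

This is the mechanism generate_stats() now relies on instead of @cache.memoize: cache_stats() in threads.py calls generate_stats(regen=True) every 5 seconds to refresh the 'proxy_stats' key, and normal callers read the cached dict back.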
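
Note (not part of the diff): the rewritten deep_sort in llm_server/helpers.py keeps the same intent, a deterministic ordering of the stats payload, but is now purely recursive. A small illustration with made-up data:

    from llm_server.helpers import deep_sort

    payload = {'b': [3, 1, 2], 'a': {'y': 1, 'x': 2}}
    ordered = deep_sort(payload)
    # Keys come back sorted at every level ('a' before 'b', 'x' before 'y'),
    # and plain lists are sorted as well ([1, 2, 3]).

One behavioral difference worth noting: the old version sorted lists with key=lambda x: json.dumps(x), so lists of dicts had a stable order; the new version sorts list elements directly, which assumes the elements are mutually comparable (a list of dicts would now raise TypeError).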