From 11a0b6541f1fc61ac762824015045edf2b1ab0b7 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Wed, 23 Aug 2023 22:01:06 -0600 Subject: [PATCH] fix some stuff related to gunicorn workers --- config/config.yml | 3 +++ llm_server/database.py | 19 ++++++++++++----- llm_server/routes/cache.py | 35 +++++++++++++++++++++++++++++++- llm_server/routes/queue.py | 7 +++++++ llm_server/routes/stats.py | 11 +++++++++- llm_server/routes/v1/generate.py | 6 +++--- llm_server/routes/v1/proxy.py | 12 +++++------ server.py | 14 +++++++++---- 8 files changed, 87 insertions(+), 20 deletions(-) diff --git a/config/config.yml b/config/config.yml index e1a4a13..b40e7bb 100644 --- a/config/config.yml +++ b/config/config.yml @@ -10,6 +10,9 @@ token_limit: 7777 backend_url: https://10.0.0.86:8083 +# Load the number of prompts from the database to display on the stats page. +load_num_prompts: true + # Path that is shown to users for them to connect to frontend_api_client: /api diff --git a/llm_server/database.py b/llm_server/database.py index 9040074..52f74b8 100644 --- a/llm_server/database.py +++ b/llm_server/database.py @@ -10,9 +10,9 @@ from llm_server import opts tokenizer = tiktoken.get_encoding("cl100k_base") -def init_db(db_path): - if not Path(db_path).exists(): - conn = sqlite3.connect(db_path) +def init_db(): + if not Path(opts.database_path).exists(): + conn = sqlite3.connect(opts.database_path) c = conn.cursor() c.execute(''' CREATE TABLE prompts ( @@ -43,7 +43,7 @@ def init_db(db_path): conn.close() -def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code): +def log_prompt(ip, token, prompt, response, parameters, headers, backend_response_code): prompt_tokens = len(tokenizer.encode(prompt)) response_tokens = len(tokenizer.encode(response)) @@ -51,7 +51,7 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backen prompt = response = None timestamp = int(time.time()) - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(opts.database_path) c = conn.cursor() c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) @@ -82,3 +82,12 @@ def increment_uses(api_key): conn.commit() return True return False + + +def get_number_of_rows(table_name): + conn = sqlite3.connect(opts.database_path) + cur = conn.cursor() + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + result = cur.fetchone() + conn.close() + return result[0] diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py index 58e74f3..b4901b3 100644 --- a/llm_server/routes/cache.py +++ b/llm_server/routes/cache.py @@ -2,4 +2,37 @@ from flask_caching import Cache from redis import Redis cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'}) -redis = Redis() + + +# redis = Redis() + +class RedisWrapper: + """ + A wrapper class to set prefixes to keys. + """ + + def __init__(self, prefix, **kwargs): + self.redis = Redis(**kwargs) + self.prefix = prefix + + def set(self, key, value): + return self.redis.set(f"{self.prefix}:{key}", value) + + def get(self, key): + return self.redis.get(f"{self.prefix}:{key}") + + def incr(self, key, amount=1): + return self.redis.incr(f"{self.prefix}:{key}", amount) + + def decr(self, key, amount=1): + return self.redis.decr(f"{self.prefix}:{key}", amount) + + def flush(self): + flushed = [] + for key in self.redis.scan_iter(f'{self.prefix}:*'): + flushed.append(key) + self.redis.delete(key) + return flushed + + +redis = RedisWrapper('local_llm') diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 530a3a0..1df6bf2 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -3,6 +3,7 @@ import threading import time from llm_server.llm.generator import generator +from llm_server.routes.cache import redis from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock @@ -40,8 +41,12 @@ class DataEvent(threading.Event): def worker(): + global active_gen_workers while True: priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get() + + redis.incr('active_gen_workers') + start_time = time.time() success, response, error_msg = generator(request_json_body) @@ -53,6 +58,8 @@ def worker(): event.data = (success, response, error_msg) event.set() + redis.decr('active_gen_workers') + def start_workers(num_workers: int): for _ in range(num_workers): diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index 4cb4ba9..cf07127 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -52,7 +52,7 @@ def process_avg_gen_time(): time.sleep(5) -def get_count(): +def get_total_proompts(): count = redis.get('proompts') if count is None: count = 0 @@ -61,6 +61,15 @@ def get_count(): return count +def get_active_gen_workers(): + active_gen_workers = redis.get('active_gen_workers') + if active_gen_workers is None: + count = 0 + else: + count = int(active_gen_workers) + return count + + class SemaphoreCheckerThread(Thread): proompters_1_min = 0 recent_prompters = {} diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index a93814f..e516fc3 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -76,7 +76,7 @@ def generate(): else: raise Exception - log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) + log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) return jsonify({ 'code': 500, 'error': 'failed to reach backend', @@ -95,7 +95,7 @@ def generate(): else: raise Exception - log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) + log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) return jsonify({ **response_json_body }), 200 @@ -111,7 +111,7 @@ def generate(): } else: raise Exception - log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) + log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code) return jsonify({ 'code': 500, 'error': 'the backend did not return valid JSON', diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py index d2346b5..6d954b5 100644 --- a/llm_server/routes/v1/proxy.py +++ b/llm_server/routes/v1/proxy.py @@ -6,14 +6,13 @@ from flask import jsonify, request from llm_server import opts from . import bp from .. import stats -from ..cache import cache from ..queue import priority_queue -from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time +from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers from ...llm.info import get_running_model @bp.route('/stats', methods=['GET']) -@cache.cached(timeout=5, query_string=True) +# @cache.cached(timeout=5, query_string=True) def get_stats(): model_list, error = get_running_model() # will return False when the fetch fails if isinstance(model_list, bool): @@ -29,12 +28,13 @@ def get_stats(): # estimated_wait = int(sum(waits) / len(waits)) average_generation_time = int(calculate_avg_gen_time()) + proompters_in_queue = len(priority_queue) + get_active_gen_workers() return jsonify({ 'stats': { - 'prompts_in_queue': len(priority_queue), + 'prompts_in_queue': proompters_in_queue, 'proompters_1_min': SemaphoreCheckerThread.proompters_1_min, - 'total_proompts': stats.get_count(), + 'total_proompts': stats.get_total_proompts(), 'uptime': int((datetime.now() - stats.server_start_time).total_seconds()), 'average_generation_elapsed_sec': average_generation_time, }, @@ -44,7 +44,7 @@ def get_stats(): 'endpoints': { 'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}', }, - 'estimated_wait_sec': int(average_generation_time * len(priority_queue)), + 'estimated_wait_sec': int(average_generation_time * proompters_in_queue), 'timestamp': int(time.time()), 'openaiKeys': '∞', 'anthropicKeys': '∞', diff --git a/server.py b/server.py index c5cc41f..1cbeeaa 100644 --- a/server.py +++ b/server.py @@ -7,9 +7,9 @@ from flask import Flask, jsonify from llm_server import opts from llm_server.config import ConfigLoader -from llm_server.database import init_db +from llm_server.database import get_number_of_rows, init_db from llm_server.helpers import resolve_path -from llm_server.routes.cache import cache +from llm_server.routes.cache import cache, redis from llm_server.routes.helpers.http import cache_control from llm_server.routes.queue import start_workers from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time @@ -23,7 +23,7 @@ if config_path_environ: else: config_path = Path(script_path, 'config', 'config.yml') -default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True} +default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True, 'load_num_prompts': False} required_vars = ['token_limit'] config_loader = ConfigLoader(config_path, default_vars, required_vars) success, config, msg = config_loader.load_config() @@ -38,7 +38,7 @@ if config['database_path'].startswith('./'): config['database_path'] = resolve_path(script_path, config['database_path'].strip('./')) opts.database_path = resolve_path(config['database_path']) -init_db(opts.database_path) +init_db() if config['mode'] not in ['oobabooga', 'hf-textgen']: print('Unknown mode:', config['mode']) @@ -55,6 +55,12 @@ if not opts.verify_ssl: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) +flushed_keys = redis.flush() +print('Flushed', len(flushed_keys), 'keys from Redis.') + +if config['load_num_prompts']: + redis.set('proompts', get_number_of_rows('prompts')) + start_workers(opts.concurrent_gens) # cleanup_thread = Thread(target=elapsed_times_cleanup)