fix some stuff related to gunicorn workers

Cyberes 2023-08-23 22:01:06 -06:00
parent 02c07bbd53
commit 11a0b6541f
8 changed files with 87 additions and 20 deletions

View File

@@ -10,6 +10,9 @@ token_limit: 7777
 backend_url: https://10.0.0.86:8083
+# Load the number of prompts from the database to display on the stats page.
+load_num_prompts: true
 # Path that is shown to users for them to connect to
 frontend_api_client: /api

View File

@@ -10,9 +10,9 @@ from llm_server import opts
 tokenizer = tiktoken.get_encoding("cl100k_base")
-def init_db(db_path):
-    if not Path(db_path).exists():
-        conn = sqlite3.connect(db_path)
+def init_db():
+    if not Path(opts.database_path).exists():
+        conn = sqlite3.connect(opts.database_path)
         c = conn.cursor()
         c.execute('''
             CREATE TABLE prompts (
@@ -43,7 +43,7 @@ def init_db(db_path):
     conn.close()
-def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
+def log_prompt(ip, token, prompt, response, parameters, headers, backend_response_code):
     prompt_tokens = len(tokenizer.encode(prompt))
     response_tokens = len(tokenizer.encode(response))
@@ -51,7 +51,7 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
         prompt = response = None
     timestamp = int(time.time())
-    conn = sqlite3.connect(db_path)
+    conn = sqlite3.connect(opts.database_path)
     c = conn.cursor()
     c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
               (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
@@ -82,3 +82,12 @@ def increment_uses(api_key):
         conn.commit()
         return True
     return False
+def get_number_of_rows(table_name):
+    conn = sqlite3.connect(opts.database_path)
+    cur = conn.cursor()
+    cur.execute(f'SELECT COUNT(*) FROM {table_name}')
+    result = cur.fetchone()
+    conn.close()
+    return result[0]
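For context, a rough sketch of the query shape the new helper runs, exercised against a throwaway in-memory database instead of opts.database_path (the table and row here are purely illustrative):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE prompts (ip TEXT)')
    conn.execute("INSERT INTO prompts VALUES ('127.0.0.1')")
    cur = conn.cursor()
    cur.execute('SELECT COUNT(*) FROM prompts')
    print(cur.fetchone()[0])  # prints 1, the same COUNT(*) result get_number_of_rows() returns
    conn.close()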

View File

@@ -2,4 +2,37 @@ from flask_caching import Cache
 from redis import Redis
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
-redis = Redis()
+# redis = Redis()
+class RedisWrapper:
+    """
+    A wrapper class to set prefixes to keys.
+    """
+    def __init__(self, prefix, **kwargs):
+        self.redis = Redis(**kwargs)
+        self.prefix = prefix
+    def set(self, key, value):
+        return self.redis.set(f"{self.prefix}:{key}", value)
+    def get(self, key):
+        return self.redis.get(f"{self.prefix}:{key}")
+    def incr(self, key, amount=1):
+        return self.redis.incr(f"{self.prefix}:{key}", amount)
+    def decr(self, key, amount=1):
+        return self.redis.decr(f"{self.prefix}:{key}", amount)
+    def flush(self):
+        flushed = []
+        for key in self.redis.scan_iter(f'{self.prefix}:*'):
+            flushed.append(key)
+            self.redis.delete(key)
+        return flushed
+redis = RedisWrapper('local_llm')
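A minimal sketch of how the prefixed wrapper behaves (assumes a Redis server on localhost:6379; the key name is only an example):

    from llm_server.routes.cache import redis

    redis.set('active_gen_workers', 0)      # stored in Redis as 'local_llm:active_gen_workers'
    redis.incr('active_gen_workers')        # returns 1
    print(redis.get('active_gen_workers'))  # b'1' -- redis-py returns bytes
    print(redis.flush())                    # deletes every 'local_llm:*' key and returns the deleted keys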

View File

@@ -3,6 +3,7 @@ import threading
 import time
 from llm_server.llm.generator import generator
+from llm_server.routes.cache import redis
 from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
@@ -40,8 +41,12 @@ class DataEvent(threading.Event):
 def worker():
+    global active_gen_workers
     while True:
         priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        redis.incr('active_gen_workers')
         start_time = time.time()
         success, response, error_msg = generator(request_json_body)
@@ -53,6 +58,8 @@ def worker():
         event.data = (success, response, error_msg)
         event.set()
+        redis.decr('active_gen_workers')
 def start_workers(num_workers: int):
     for _ in range(num_workers):
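The point of moving the counter into Redis: each gunicorn worker is a separate process, so a plain module-level variable would never be seen by the other workers or by the stats endpoint. A minimal sketch of the shared-counter pattern, assuming Redis is running locally (the job callable and the try/finally guard are illustrative, not part of this diff):

    from redis import Redis

    r = Redis()

    def run_one(job):
        r.incr('local_llm:active_gen_workers')    # visible to every gunicorn process
        try:
            job()                                 # stand-in for the backend generation call
        finally:
            r.decr('local_llm:active_gen_workers')

    run_one(lambda: None)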

View File

@ -52,7 +52,7 @@ def process_avg_gen_time():
time.sleep(5) time.sleep(5)
def get_count(): def get_total_proompts():
count = redis.get('proompts') count = redis.get('proompts')
if count is None: if count is None:
count = 0 count = 0
@ -61,6 +61,15 @@ def get_count():
return count return count
def get_active_gen_workers():
active_gen_workers = redis.get('active_gen_workers')
if active_gen_workers is None:
count = 0
else:
count = int(active_gen_workers)
return count
class SemaphoreCheckerThread(Thread): class SemaphoreCheckerThread(Thread):
proompters_1_min = 0 proompters_1_min = 0
recent_prompters = {} recent_prompters = {}

View File

@@ -76,7 +76,7 @@ def generate():
         else:
             raise Exception
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
@@ -95,7 +95,7 @@ def generate():
         else:
             raise Exception
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             **response_json_body
         }), 200
@@ -111,7 +111,7 @@ def generate():
             }
         else:
             raise Exception
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'the backend did not return valid JSON',

View File

@@ -6,14 +6,13 @@ from flask import jsonify, request
 from llm_server import opts
 from . import bp
 from .. import stats
-from ..cache import cache
 from ..queue import priority_queue
-from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
+from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers
 from ...llm.info import get_running_model
 @bp.route('/stats', methods=['GET'])
-@cache.cached(timeout=5, query_string=True)
+# @cache.cached(timeout=5, query_string=True)
 def get_stats():
     model_list, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_list, bool):
@@ -29,12 +28,13 @@ def get_stats():
     # estimated_wait = int(sum(waits) / len(waits))
     average_generation_time = int(calculate_avg_gen_time())
+    proompters_in_queue = len(priority_queue) + get_active_gen_workers()
     return jsonify({
         'stats': {
-            'prompts_in_queue': len(priority_queue),
+            'prompts_in_queue': proompters_in_queue,
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
-            'total_proompts': stats.get_count(),
+            'total_proompts': stats.get_total_proompts(),
             'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
             'average_generation_elapsed_sec': average_generation_time,
         },
@@ -44,7 +44,7 @@ def get_stats():
         'endpoints': {
             'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
         },
-        'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
+        'estimated_wait_sec': int(average_generation_time * proompters_in_queue),
         'timestamp': int(time.time()),
         'openaiKeys': '',
         'anthropicKeys': '',
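To illustrate the adjusted estimate with made-up numbers: with 4 prompts still waiting in the queue, 2 generations already running, and an average generation time of 20 seconds, proompters_in_queue is 4 + 2 = 6 and estimated_wait_sec is 20 * 6 = 120, where the previous len(priority_queue)-only calculation would have reported 80.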

View File

@@ -7,9 +7,9 @@ from flask import Flask, jsonify
 from llm_server import opts
 from llm_server.config import ConfigLoader
-from llm_server.database import init_db
+from llm_server.database import get_number_of_rows, init_db
 from llm_server.helpers import resolve_path
-from llm_server.routes.cache import cache
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
@@ -23,7 +23,7 @@ if config_path_environ:
 else:
     config_path = Path(script_path, 'config', 'config.yml')
-default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True}
+default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True, 'load_num_prompts': False}
 required_vars = ['token_limit']
 config_loader = ConfigLoader(config_path, default_vars, required_vars)
 success, config, msg = config_loader.load_config()
@@ -38,7 +38,7 @@ if config['database_path'].startswith('./'):
     config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
 opts.database_path = resolve_path(config['database_path'])
-init_db(opts.database_path)
+init_db()
 if config['mode'] not in ['oobabooga', 'hf-textgen']:
     print('Unknown mode:', config['mode'])
@@ -55,6 +55,12 @@ if not opts.verify_ssl:
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+flushed_keys = redis.flush()
+print('Flushed', len(flushed_keys), 'keys from Redis.')
+if config['load_num_prompts']:
+    redis.set('proompts', get_number_of_rows('prompts'))
 start_workers(opts.concurrent_gens)
 # cleanup_thread = Thread(target=elapsed_times_cleanup)