fix some stuff related to gunicorn workers

This commit is contained in:
Cyberes 2023-08-23 22:01:06 -06:00
parent 02c07bbd53
commit 11a0b6541f
8 changed files with 87 additions and 20 deletions

View File

@ -10,6 +10,9 @@ token_limit: 7777
backend_url: https://10.0.0.86:8083
# Load the number of prompts from the database to display on the stats page.
load_num_prompts: true
# Path that is shown to users for them to connect to
frontend_api_client: /api

View File

@ -10,9 +10,9 @@ from llm_server import opts
tokenizer = tiktoken.get_encoding("cl100k_base")
def init_db(db_path):
if not Path(db_path).exists():
conn = sqlite3.connect(db_path)
def init_db():
if not Path(opts.database_path).exists():
conn = sqlite3.connect(opts.database_path)
c = conn.cursor()
c.execute('''
CREATE TABLE prompts (
@ -43,7 +43,7 @@ def init_db(db_path):
conn.close()
def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
def log_prompt(ip, token, prompt, response, parameters, headers, backend_response_code):
prompt_tokens = len(tokenizer.encode(prompt))
response_tokens = len(tokenizer.encode(response))
@ -51,7 +51,7 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backen
prompt = response = None
timestamp = int(time.time())
conn = sqlite3.connect(db_path)
conn = sqlite3.connect(opts.database_path)
c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
@ -82,3 +82,12 @@ def increment_uses(api_key):
conn.commit()
return True
return False
def get_number_of_rows(table_name):
conn = sqlite3.connect(opts.database_path)
cur = conn.cursor()
cur.execute(f'SELECT COUNT(*) FROM {table_name}')
result = cur.fetchone()
conn.close()
return result[0]

View File

@ -2,4 +2,37 @@ from flask_caching import Cache
from redis import Redis
cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
redis = Redis()
# redis = Redis()
class RedisWrapper:
"""
A wrapper class to set prefixes to keys.
"""
def __init__(self, prefix, **kwargs):
self.redis = Redis(**kwargs)
self.prefix = prefix
def set(self, key, value):
return self.redis.set(f"{self.prefix}:{key}", value)
def get(self, key):
return self.redis.get(f"{self.prefix}:{key}")
def incr(self, key, amount=1):
return self.redis.incr(f"{self.prefix}:{key}", amount)
def decr(self, key, amount=1):
return self.redis.decr(f"{self.prefix}:{key}", amount)
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):
flushed.append(key)
self.redis.delete(key)
return flushed
redis = RedisWrapper('local_llm')

View File

@ -3,6 +3,7 @@ import threading
import time
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
@ -40,8 +41,12 @@ class DataEvent(threading.Event):
def worker():
global active_gen_workers
while True:
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
redis.incr('active_gen_workers')
start_time = time.time()
success, response, error_msg = generator(request_json_body)
@ -53,6 +58,8 @@ def worker():
event.data = (success, response, error_msg)
event.set()
redis.decr('active_gen_workers')
def start_workers(num_workers: int):
for _ in range(num_workers):

View File

@ -52,7 +52,7 @@ def process_avg_gen_time():
time.sleep(5)
def get_count():
def get_total_proompts():
count = redis.get('proompts')
if count is None:
count = 0
@ -61,6 +61,15 @@ def get_count():
return count
def get_active_gen_workers():
active_gen_workers = redis.get('active_gen_workers')
if active_gen_workers is None:
count = 0
else:
count = int(active_gen_workers)
return count
class SemaphoreCheckerThread(Thread):
proompters_1_min = 0
recent_prompters = {}

View File

@ -76,7 +76,7 @@ def generate():
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'failed to reach backend',
@ -95,7 +95,7 @@ def generate():
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
**response_json_body
}), 200
@ -111,7 +111,7 @@ def generate():
}
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'the backend did not return valid JSON',

View File

@ -6,14 +6,13 @@ from flask import jsonify, request
from llm_server import opts
from . import bp
from .. import stats
from ..cache import cache
from ..queue import priority_queue
from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers
from ...llm.info import get_running_model
@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=5, query_string=True)
# @cache.cached(timeout=5, query_string=True)
def get_stats():
model_list, error = get_running_model() # will return False when the fetch fails
if isinstance(model_list, bool):
@ -29,12 +28,13 @@ def get_stats():
# estimated_wait = int(sum(waits) / len(waits))
average_generation_time = int(calculate_avg_gen_time())
proompters_in_queue = len(priority_queue) + get_active_gen_workers()
return jsonify({
'stats': {
'prompts_in_queue': len(priority_queue),
'prompts_in_queue': proompters_in_queue,
'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
'total_proompts': stats.get_count(),
'total_proompts': stats.get_total_proompts(),
'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
'average_generation_elapsed_sec': average_generation_time,
},
@ -44,7 +44,7 @@ def get_stats():
'endpoints': {
'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
},
'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
'estimated_wait_sec': int(average_generation_time * proompters_in_queue),
'timestamp': int(time.time()),
'openaiKeys': '',
'anthropicKeys': '',

View File

@ -7,9 +7,9 @@ from flask import Flask, jsonify
from llm_server import opts
from llm_server.config import ConfigLoader
from llm_server.database import init_db
from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
from llm_server.routes.cache import cache
from llm_server.routes.cache import cache, redis
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
@ -23,7 +23,7 @@ if config_path_environ:
else:
config_path = Path(script_path, 'config', 'config.yml')
default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True}
default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True, 'load_num_prompts': False}
required_vars = ['token_limit']
config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()
@ -38,7 +38,7 @@ if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
opts.database_path = resolve_path(config['database_path'])
init_db(opts.database_path)
init_db()
if config['mode'] not in ['oobabooga', 'hf-textgen']:
print('Unknown mode:', config['mode'])
@ -55,6 +55,12 @@ if not opts.verify_ssl:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
start_workers(opts.concurrent_gens)
# cleanup_thread = Thread(target=elapsed_times_cleanup)