fix some stuff related to gunicorn workers
parent 02c07bbd53
commit 11a0b6541f
@@ -10,6 +10,9 @@ token_limit: 7777
 backend_url: https://10.0.0.86:8083
 
+# Load the number of prompts from the database to display on the stats page.
+load_num_prompts: true
+
 # Path that is shown to users for them to connect to
 frontend_api_client: /api
@@ -10,9 +10,9 @@ from llm_server import opts
 tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
-def init_db(db_path):
-    if not Path(db_path).exists():
-        conn = sqlite3.connect(db_path)
+def init_db():
+    if not Path(opts.database_path).exists():
+        conn = sqlite3.connect(opts.database_path)
         c = conn.cursor()
         c.execute('''
             CREATE TABLE prompts (
@@ -43,7 +43,7 @@ def init_db(db_path):
     conn.close()
 
 
-def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
+def log_prompt(ip, token, prompt, response, parameters, headers, backend_response_code):
     prompt_tokens = len(tokenizer.encode(prompt))
     response_tokens = len(tokenizer.encode(response))
@@ -51,7 +51,7 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backen
         prompt = response = None
 
     timestamp = int(time.time())
-    conn = sqlite3.connect(db_path)
+    conn = sqlite3.connect(opts.database_path)
     c = conn.cursor()
     c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
               (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
@@ -82,3 +82,12 @@ def increment_uses(api_key):
         conn.commit()
         return True
     return False
+
+
+def get_number_of_rows(table_name):
+    conn = sqlite3.connect(opts.database_path)
+    cur = conn.cursor()
+    cur.execute(f'SELECT COUNT(*) FROM {table_name}')
+    result = cur.fetchone()
+    conn.close()
+    return result[0]
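Note: a minimal usage sketch of the reworked database helpers (not part of the commit). It assumes opts.database_path has been set before the module is used, as server.py now does; the path and values shown are illustrative.

from llm_server import opts
from llm_server.database import get_number_of_rows, init_db, log_prompt

opts.database_path = '/tmp/proxy-server.db'  # illustrative; normally set by server.py
init_db()  # creates the prompts table on first run

log_prompt(
    ip='127.0.0.1',
    token=None,
    prompt='hello',
    response='hi',
    parameters={'max_new_tokens': 16},
    headers={'User-Agent': 'example'},
    backend_response_code=200,
)
print(get_number_of_rows('prompts'))  # row count later used to seed the 'proompts' counter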
@@ -2,4 +2,37 @@ from flask_caching import Cache
 from redis import Redis
 
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
-redis = Redis()
+
+
+# redis = Redis()
+
+class RedisWrapper:
+    """
+    A wrapper class to set prefixes to keys.
+    """
+
+    def __init__(self, prefix, **kwargs):
+        self.redis = Redis(**kwargs)
+        self.prefix = prefix
+
+    def set(self, key, value):
+        return self.redis.set(f"{self.prefix}:{key}", value)
+
+    def get(self, key):
+        return self.redis.get(f"{self.prefix}:{key}")
+
+    def incr(self, key, amount=1):
+        return self.redis.incr(f"{self.prefix}:{key}", amount)
+
+    def decr(self, key, amount=1):
+        return self.redis.decr(f"{self.prefix}:{key}", amount)
+
+    def flush(self):
+        flushed = []
+        for key in self.redis.scan_iter(f'{self.prefix}:*'):
+            flushed.append(key)
+            self.redis.delete(key)
+        return flushed
+
+
+redis = RedisWrapper('local_llm')
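Note: a short sketch of how the new RedisWrapper is intended to be used (not part of the commit). Every key is stored under the 'local_llm:' prefix, so counters are shared by all gunicorn worker processes and flush() only touches this application's keys.

from llm_server.routes.cache import redis  # the RedisWrapper('local_llm') instance

redis.set('proompts', 0)              # stored in Redis as 'local_llm:proompts'
redis.incr('active_gen_workers')      # atomic increment, visible to every process
redis.decr('active_gen_workers')
raw = redis.get('proompts')           # redis-py returns bytes (b'0') or None
deleted = redis.flush()               # deletes only 'local_llm:*' keys, returns their names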
@@ -3,6 +3,7 @@ import threading
 import time
 
 from llm_server.llm.generator import generator
+from llm_server.routes.cache import redis
 from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
@@ -40,8 +41,12 @@ class DataEvent(threading.Event):
 
 
 def worker():
-    global active_gen_workers
     while True:
         priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
 
+        redis.incr('active_gen_workers')
+
         start_time = time.time()
         success, response, error_msg = generator(request_json_body)
@@ -53,6 +58,8 @@ def worker():
         event.data = (success, response, error_msg)
         event.set()
 
+        redis.decr('active_gen_workers')
+
 
 def start_workers(num_workers: int):
     for _ in range(num_workers):
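Note: the worker now tracks how many generations are in flight in Redis instead of a module-level global, since each gunicorn worker is a separate process and would otherwise keep its own counter. Below is a minimal sketch of the same incr/decr-around-work pattern; the helper name and the try/finally guard are additions for illustration and are not in the diff itself.

from llm_server.llm.generator import generator
from llm_server.routes.cache import redis

def handle_one_request(request_json_body):
    redis.incr('active_gen_workers')         # mark a generation as in flight
    try:
        return generator(request_json_body)  # the actual backend call
    finally:
        redis.decr('active_gen_workers')     # always release the slot, even on error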
@@ -52,7 +52,7 @@ def process_avg_gen_time():
         time.sleep(5)
 
 
-def get_count():
+def get_total_proompts():
     count = redis.get('proompts')
     if count is None:
         count = 0
@@ -61,6 +61,15 @@ def get_count():
     return count
 
 
+def get_active_gen_workers():
+    active_gen_workers = redis.get('active_gen_workers')
+    if active_gen_workers is None:
+        count = 0
+    else:
+        count = int(active_gen_workers)
+    return count
+
+
 class SemaphoreCheckerThread(Thread):
     proompters_1_min = 0
     recent_prompters = {}
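Note: the int() conversion in get_active_gen_workers() matters because redis-py returns bytes (or None for a missing key), never an integer. A tiny illustration, not part of the commit:

from llm_server.routes.cache import redis

raw = redis.get('active_gen_workers')    # e.g. b'2', or None if the key was never set
count = int(raw) if raw is not None else 0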
@@ -76,7 +76,7 @@ def generate():
         else:
             raise Exception
 
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
@@ -95,7 +95,7 @@ def generate():
         else:
             raise Exception
 
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             **response_json_body
         }), 200
@@ -111,7 +111,7 @@ def generate():
             }
         else:
             raise Exception
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'the backend did not return valid JSON',
@@ -6,14 +6,13 @@ from flask import jsonify, request
 from llm_server import opts
 from . import bp
 from .. import stats
 from ..cache import cache
 from ..queue import priority_queue
-from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
+from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers
 from ...llm.info import get_running_model
 
 
 @bp.route('/stats', methods=['GET'])
-@cache.cached(timeout=5, query_string=True)
+# @cache.cached(timeout=5, query_string=True)
 def get_stats():
     model_list, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_list, bool):
@@ -29,12 +28,13 @@ def get_stats():
     # estimated_wait = int(sum(waits) / len(waits))
 
     average_generation_time = int(calculate_avg_gen_time())
+    proompters_in_queue = len(priority_queue) + get_active_gen_workers()
 
     return jsonify({
         'stats': {
-            'prompts_in_queue': len(priority_queue),
+            'prompts_in_queue': proompters_in_queue,
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
-            'total_proompts': stats.get_count(),
+            'total_proompts': stats.get_total_proompts(),
             'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
             'average_generation_elapsed_sec': average_generation_time,
         },
@@ -44,7 +44,7 @@ def get_stats():
         'endpoints': {
             'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
         },
-        'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
+        'estimated_wait_sec': int(average_generation_time * proompters_in_queue),
         'timestamp': int(time.time()),
         'openaiKeys': '∞',
         'anthropicKeys': '∞',
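Note: a worked example of the new wait estimate (numbers are illustrative). prompts_in_queue now counts requests that are actively being generated as well as those still waiting in the queue.

average_generation_time = 12              # seconds, from calculate_avg_gen_time()
queued = 3                                # len(priority_queue)
active = 2                                # get_active_gen_workers()
proompters_in_queue = queued + active     # 5
estimated_wait_sec = int(average_generation_time * proompters_in_queue)  # 60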
server.py
@@ -7,9 +7,9 @@ from flask import Flask, jsonify
 
 from llm_server import opts
 from llm_server.config import ConfigLoader
-from llm_server.database import init_db
+from llm_server.database import get_number_of_rows, init_db
 from llm_server.helpers import resolve_path
-from llm_server.routes.cache import cache
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
@@ -23,7 +23,7 @@ if config_path_environ:
 else:
     config_path = Path(script_path, 'config', 'config.yml')
 
-default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True}
+default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True, 'load_num_prompts': False}
 required_vars = ['token_limit']
 config_loader = ConfigLoader(config_path, default_vars, required_vars)
 success, config, msg = config_loader.load_config()
@@ -38,7 +38,7 @@ if config['database_path'].startswith('./'):
     config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
 
 opts.database_path = resolve_path(config['database_path'])
-init_db(opts.database_path)
+init_db()
 
 if config['mode'] not in ['oobabooga', 'hf-textgen']:
     print('Unknown mode:', config['mode'])
@@ -55,6 +55,12 @@ if not opts.verify_ssl:
 
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
+flushed_keys = redis.flush()
+print('Flushed', len(flushed_keys), 'keys from Redis.')
+
+if config['load_num_prompts']:
+    redis.set('proompts', get_number_of_rows('prompts'))
+
 start_workers(opts.concurrent_gens)
 
 # cleanup_thread = Thread(target=elapsed_times_cleanup)
|
||||
|
|
Reference in New Issue