share state across gunicorn workers via Redis: track active generation workers and the total prompt count there, and read the database path from opts instead of passing it around
parent 02c07bbd53
commit 11a0b6541f
@@ -10,6 +10,9 @@ token_limit: 7777
 backend_url: https://10.0.0.86:8083
 
+# Load the number of prompts from the database to display on the stats page.
+load_num_prompts: true
+
 # Path that is shown to users for them to connect to
 frontend_api_client: /api
 

@@ -10,9 +10,9 @@ from llm_server import opts
 tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
-def init_db(db_path):
-    if not Path(db_path).exists():
-        conn = sqlite3.connect(db_path)
+def init_db():
+    if not Path(opts.database_path).exists():
+        conn = sqlite3.connect(opts.database_path)
         c = conn.cursor()
         c.execute('''
             CREATE TABLE prompts (
@@ -43,7 +43,7 @@ def init_db(db_path):
     conn.close()
 
 
-def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
+def log_prompt(ip, token, prompt, response, parameters, headers, backend_response_code):
     prompt_tokens = len(tokenizer.encode(prompt))
     response_tokens = len(tokenizer.encode(response))
 
@@ -51,7 +51,7 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backen
         prompt = response = None
 
     timestamp = int(time.time())
-    conn = sqlite3.connect(db_path)
+    conn = sqlite3.connect(opts.database_path)
     c = conn.cursor()
     c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
               (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
@@ -82,3 +82,12 @@ def increment_uses(api_key):
         conn.commit()
         return True
     return False
+
+
+def get_number_of_rows(table_name):
+    conn = sqlite3.connect(opts.database_path)
+    cur = conn.cursor()
+    cur.execute(f'SELECT COUNT(*) FROM {table_name}')
+    result = cur.fetchone()
+    conn.close()
+    return result[0]

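A minimal sketch of how the reworked database helpers are called after this change, assuming opts.database_path has already been set the way server.py does further down; the path and argument values here are invented for illustration:

from llm_server import opts
from llm_server.database import init_db, log_prompt, get_number_of_rows

opts.database_path = '/tmp/proxy-server.db'   # hypothetical path; normally resolved from the config file
init_db()                                     # creates the prompts table on first run
log_prompt('127.0.0.1', None, 'hello', 'hi there', {}, {}, 200)
print(get_number_of_rows('prompts'))          # 1
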
@@ -2,4 +2,37 @@ from flask_caching import Cache
 from redis import Redis
 
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
-redis = Redis()
+# redis = Redis()
+
+
+class RedisWrapper:
+    """
+    A wrapper class to set prefixes to keys.
+    """
+
+    def __init__(self, prefix, **kwargs):
+        self.redis = Redis(**kwargs)
+        self.prefix = prefix
+
+    def set(self, key, value):
+        return self.redis.set(f"{self.prefix}:{key}", value)
+
+    def get(self, key):
+        return self.redis.get(f"{self.prefix}:{key}")
+
+    def incr(self, key, amount=1):
+        return self.redis.incr(f"{self.prefix}:{key}", amount)
+
+    def decr(self, key, amount=1):
+        return self.redis.decr(f"{self.prefix}:{key}", amount)
+
+    def flush(self):
+        flushed = []
+        for key in self.redis.scan_iter(f'{self.prefix}:*'):
+            flushed.append(key)
+            self.redis.delete(key)
+        return flushed
+
+
+redis = RedisWrapper('local_llm')

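The wrapper namespaces every key under local_llm:, so flush() deletes only this app's keys rather than the whole Redis database. A short usage sketch (redis-py returns bytes from get(), hence the b'...' below):

from llm_server.routes.cache import redis

redis.set('active_gen_workers', 0)       # stored as 'local_llm:active_gen_workers'
redis.incr('active_gen_workers')         # atomic, shared by every Gunicorn worker process
print(redis.get('active_gen_workers'))   # b'1'
print(redis.flush())                     # removes only 'local_llm:*' keys and returns them
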
@@ -3,6 +3,7 @@ import threading
 import time
 
 from llm_server.llm.generator import generator
+from llm_server.routes.cache import redis
 from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
 
 
@@ -40,8 +41,12 @@ class DataEvent(threading.Event):
 
 
 def worker():
+    global active_gen_workers
     while True:
         priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+
+        redis.incr('active_gen_workers')
+
         start_time = time.time()
         success, response, error_msg = generator(request_json_body)
 
@@ -53,6 +58,8 @@ def worker():
         event.data = (success, response, error_msg)
         event.set()
 
+        redis.decr('active_gen_workers')
+
 
 def start_workers(num_workers: int):
     for _ in range(num_workers):

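The point of the Redis counter: under Gunicorn each worker process gets its own copy of module globals, so an in-process counter is invisible to the other processes and to the request handlers serving /stats; the shared Redis key is what everything actually reads. The try/finally below only illustrates the incr/decr pairing and is not what the commit does (it calls decr inline after event.set()):

redis.incr('active_gen_workers')
try:
    success, response, error_msg = generator(request_json_body)
    # ... record elapsed time and hand the result back via the DataEvent ...
finally:
    redis.decr('active_gen_workers')   # keep the shared counter balanced even if generation raises
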
@@ -52,7 +52,7 @@ def process_avg_gen_time():
         time.sleep(5)
 
 
-def get_count():
+def get_total_proompts():
     count = redis.get('proompts')
     if count is None:
         count = 0
@@ -61,6 +61,15 @@ def get_count():
     return count
 
 
+def get_active_gen_workers():
+    active_gen_workers = redis.get('active_gen_workers')
+    if active_gen_workers is None:
+        count = 0
+    else:
+        count = int(active_gen_workers)
+    return count
+
+
 class SemaphoreCheckerThread(Thread):
     proompters_1_min = 0
     recent_prompters = {}

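get_active_gen_workers() branches on None and coerces with int() because redis-py returns None for a missing key and bytes otherwise; a standalone check of that behavior (plain Redis client, no prefix wrapper):

from redis import Redis

r = Redis()
r.delete('demo_counter')
print(r.get('demo_counter'))   # None  -> the helper reports 0
r.set('demo_counter', 3)
print(r.get('demo_counter'))   # b'3'  -> needs int(...) before doing arithmetic
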
@@ -76,7 +76,7 @@ def generate():
         else:
             raise Exception
 
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
@@ -95,7 +95,7 @@ def generate():
         else:
             raise Exception
 
-        log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             **response_json_body
         }), 200
@@ -111,7 +111,7 @@ def generate():
         }
     else:
         raise Exception
-    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
+    log_prompt(client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
     return jsonify({
         'code': 500,
         'error': 'the backend did not return valid JSON',

@@ -6,14 +6,13 @@ from flask import jsonify, request
 from llm_server import opts
 from . import bp
 from .. import stats
-from ..cache import cache
 from ..queue import priority_queue
-from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
+from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers
 from ...llm.info import get_running_model
 
 
 @bp.route('/stats', methods=['GET'])
-@cache.cached(timeout=5, query_string=True)
+# @cache.cached(timeout=5, query_string=True)
 def get_stats():
     model_list, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_list, bool):
@@ -29,12 +28,13 @@ def get_stats():
     # estimated_wait = int(sum(waits) / len(waits))
 
     average_generation_time = int(calculate_avg_gen_time())
+    proompters_in_queue = len(priority_queue) + get_active_gen_workers()
 
     return jsonify({
         'stats': {
-            'prompts_in_queue': len(priority_queue),
+            'prompts_in_queue': proompters_in_queue,
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
-            'total_proompts': stats.get_count(),
+            'total_proompts': stats.get_total_proompts(),
             'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
             'average_generation_elapsed_sec': average_generation_time,
         },
@@ -44,7 +44,7 @@ def get_stats():
         'endpoints': {
             'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
         },
-        'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
+        'estimated_wait_sec': int(average_generation_time * proompters_in_queue),
         'timestamp': int(time.time()),
         'openaiKeys': '∞',
         'anthropicKeys': '∞',

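Putting the /stats changes together: the reported queue depth now also counts prompts a worker has already picked up, so the wait estimate no longer drops to zero while generations are still running. Roughly what the endpoint returns after this change (field names taken from this diff; host and values are invented, and the full payload may contain more fields than shown):

import requests

r = requests.get('https://example.host/api/v1/stats', verify=False)
print(r.json())
# {'stats': {'prompts_in_queue': 4,          # queued + currently generating
#            'proompters_1_min': 2,
#            'total_proompts': 1234,
#            'uptime': 3600,
#            'average_generation_elapsed_sec': 12},
#  'endpoints': {'blocking': 'https://example.host/api'},
#  'estimated_wait_sec': 48,                 # 12 s average * 4 in queue
#  'timestamp': 1693000000,
#  'openaiKeys': '∞',
#  'anthropicKeys': '∞'}
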
server.py | 14 ++++++++++----

@@ -7,9 +7,9 @@ from flask import Flask, jsonify
 
 from llm_server import opts
 from llm_server.config import ConfigLoader
-from llm_server.database import init_db
+from llm_server.database import get_number_of_rows, init_db
 from llm_server.helpers import resolve_path
-from llm_server.routes.cache import cache
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
@@ -23,7 +23,7 @@ if config_path_environ:
 else:
     config_path = Path(script_path, 'config', 'config.yml')
 
-default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True}
+default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': '', 'verify_ssl': True, 'load_num_prompts': False}
 required_vars = ['token_limit']
 config_loader = ConfigLoader(config_path, default_vars, required_vars)
 success, config, msg = config_loader.load_config()
@@ -38,7 +38,7 @@ if config['database_path'].startswith('./'):
     config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
 
 opts.database_path = resolve_path(config['database_path'])
-init_db(opts.database_path)
+init_db()
 
 if config['mode'] not in ['oobabooga', 'hf-textgen']:
     print('Unknown mode:', config['mode'])
@@ -55,6 +55,12 @@ if not opts.verify_ssl:
 
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
+flushed_keys = redis.flush()
+print('Flushed', len(flushed_keys), 'keys from Redis.')
+
+if config['load_num_prompts']:
+    redis.set('proompts', get_number_of_rows('prompts'))
+
 start_workers(opts.concurrent_gens)
 
 # cleanup_thread = Thread(target=elapsed_times_cleanup)

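Why flush Redis at startup (a sketch, not part of the commit): the keys outlive the Python process, so a counter left over from a previous run would otherwise skew the stats until something happened to overwrite it:

from llm_server.routes.cache import redis

redis.set('active_gen_workers', 3)       # pretend a previous run died mid-generation
# ... the server restarts and runs the new startup block ...
redis.flush()                            # clears all 'local_llm:*' keys
print(redis.get('active_gen_workers'))   # None -> get_active_gen_workers() reports 0 again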