convert to gunicorn

Cyberes 2023-09-26 13:32:33 -06:00
parent 0eb901cb52
commit e0af2ea9c5
10 changed files with 76 additions and 52 deletions

gunicorn.py Normal file
View File

@@ -0,0 +1,13 @@
try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import server


def on_starting(s):
    server.pre_fork(s)
    print('Startup complete!')
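Gunicorn treats this new file as its Python configuration module: server hooks defined in it run in the master process, and on_starting fires once before any workers are forked, which is what makes it a safe place for the one-time server.pre_fork() call. The commit does not show how gunicorn is launched or configured beyond that hook; a minimal sketch of what the rest of such a config might hold, with the launch command, bind address, worker count, and worker class all assumed rather than taken from this diff:

# Hypothetical additions to gunicorn.py -- none of these values appear in this commit.
# Assumed launch command: gunicorn -c gunicorn.py server:app
bind = '0.0.0.0:5000'     # address:port the workers listen on (assumption)
workers = 4               # number of forked worker processes (assumption)
worker_class = 'gevent'   # pairs with the gevent monkey-patching above (assumption)

def post_fork(server, worker):
    # Standard gunicorn server hook: runs in each worker right after the fork,
    # a natural place for per-process setup such as fresh database connections.
    server.log.info('Worker %s forked', worker.pid)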

View File

@@ -6,6 +6,7 @@ import llm_server
from llm_server import opts
from llm_server.database.conn import db_pool
from llm_server.llm.vllm import tokenize
from llm_server.routes.cache import redis
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
@@ -33,6 +34,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
    if token:
        increment_token_uses(token)

    running_model = redis.get('running_model', str, 'ERROR')
    timestamp = int(time.time())
    conn = db_pool.connection()
    cursor = conn.cursor()
@@ -42,7 +45,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
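The three-argument redis.get('running_model', str, 'ERROR') pattern introduced here (and used throughout the rest of this commit) replaces reads of the opts.running_model module global: under gunicorn each worker is a separate process, so a global set in one worker is invisible to the others, while Redis gives every process the same view. The wrapper behind llm_server.routes.cache.redis is not shown in this diff; a rough sketch of the shape such a wrapper might have (class name, key prefix, and argument order are all assumptions) is:

import redis as redis_lib

class RedisWrapper:
    # Hypothetical typed get/set wrapper around redis-py.
    def __init__(self, prefix: str = 'local_llm', host: str = 'localhost', port: int = 6379):
        self.prefix = prefix
        self.redis = redis_lib.Redis(host=host, port=port)

    def set(self, key, value):
        return self.redis.set(f'{self.prefix}:{key}', value)

    def get(self, key, dtype=None, default=None):
        value = self.redis.get(f'{self.prefix}:{key}')
        if value is None:
            return default
        # redis-py returns bytes; coerce to the requested type (str, int, float, ...)
        return dtype(value.decode('utf-8')) if dtype else value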

View File

@@ -10,6 +10,7 @@ import requests
import llm_server
from llm_server import opts
from llm_server.routes.cache import redis
# TODO: make the VLLM backend return TPS and time elapsed
@@ -49,13 +50,14 @@ def transform_to_text(json_request, api_response):
    prompt_tokens = len(llm_server.llm.get_token_count(prompt))
    completion_tokens = len(llm_server.llm.get_token_count(text))
    running_model = redis.get('running_model', str, 'ERROR')

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
        "id": str(uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "model": running_model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,

View File

@@ -2,7 +2,7 @@
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'none'
running_model = 'ERROR'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None

View File

@@ -6,6 +6,7 @@ import traceback
from flask import Response, jsonify, request
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
@@ -50,7 +51,7 @@ def openai_chat_completions():
response = generator(msg_to_backend)
r_headers = dict(request.headers)
r_url = request.url
model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():

View File

@@ -4,7 +4,7 @@ import requests
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache
from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..stats import server_start_time
from ... import opts
from ...helpers import jsonify_pretty
@@ -22,6 +22,7 @@ def openai_list_models():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', str, 'ERROR')
oai = fetch_openai_models()
r = []
if opts.openai_expose_our_model:
@@ -29,13 +30,13 @@ def openai_list_models():
"object": "list",
"data": [
{
"id": opts.running_model,
"id": running_model,
"object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": opts.running_model,
"id": running_model,
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"allow_create_engine": False,

View File

@@ -14,6 +14,7 @@ from flask import jsonify
import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@@ -157,11 +158,13 @@ def build_openai_response(prompt, response, model=None):
    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    return jsonify({
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model if opts.openai_expose_our_model else model,
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [{
            "index": 0,
            "message": {

View File

@@ -46,7 +46,7 @@ def generate_stats(regen: bool = False):
online = False
else:
online = True
opts.running_model = model_name
redis.set('running_model', model_name)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0:

View File

@@ -25,35 +25,36 @@ class MainBackgroundThread(Thread):
    def run(self):
        while True:
            # TODO: unify this
            if opts.mode == 'oobabooga':
                model, err = get_running_model()
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    opts.running_model = model
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            elif opts.mode == 'vllm':
                model, err = get_running_model()
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    opts.running_model = model
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            else:
                raise Exception

            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:  # returns None on exception
                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

            # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
            # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:
                redis.set('average_output_tokens', average_output_tokens)
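The oobabooga and vllm branches above are now identical, which is presumably what the '# TODO: unify this' comment is pointing at. A minimal sketch of that unification, meant to sit inside run() and reuse the module's existing imports and the same Redis keys:

if opts.mode in ('oobabooga', 'vllm'):
    running_model, err = get_running_model()
    if err:
        print(err)
        redis.set('backend_online', 0)
    else:
        redis.set('running_model', running_model)
        redis.set('backend_online', 1)
else:
    raise Exception(f'Unknown mode: {opts.mode}')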

View File

@@ -15,6 +15,8 @@ from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
@@ -22,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
# TODO: validate system tokens before excluding them
# TODO: make sure prompts are logged even when the user cancels generation
# TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: use the current estimated wait time for ratelimit headers on openai
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
@@ -43,13 +46,18 @@ from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread, cache_stats
script_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
cache.init_app(app)
cache.clear() # clear redis cache
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
@@ -73,9 +81,6 @@ if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
# TODO: this is a MESS
opts.mode = config['mode']
opts.auth_required = config['auth_required']
@@ -108,23 +113,12 @@ if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)
if config['http_host']:
    http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
    redis.set('http_host', http_host)
    redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
    print('Set host to', redis.get('http_host', str))
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
redis.set('backend_mode', opts.mode)
if config['load_num_prompts']:
    redis.set('proompts', get_number_of_rows('prompts'))
if config['average_generation_time_mode'] not in ['database', 'minute']:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    sys.exit(1)
@@ -138,9 +132,19 @@ elif opts.mode == 'vllm':
else:
    raise Exception
app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache
def pre_fork(server):
    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    redis.set('backend_mode', opts.mode)

    if config['http_host']:
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

# Start background processes
start_workers(opts.concurrent_gens)
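Taken together with gunicorn.py above: the gunicorn master loads that config module, calls on_starting, which calls server.pre_fork(s) exactly once (flushing Redis and seeding the shared keys), and only then forks the workers, each of which imports server.py and runs the module-level Flask setup and background threads shown here. Nothing calls pre_fork() when the module is run directly without gunicorn; a hypothetical development-only fallback (not part of this commit; host and port are assumptions) would look like:

if __name__ == '__main__':
    pre_fork(None)  # the arbiter argument is unused by pre_fork() above
    app.run(host='0.0.0.0', port=5000)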
@@ -150,19 +154,15 @@ process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
# Cache the initial stats
print('Loading backend stats...')
generate_stats()
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
# This needs to be started after Flask is initialized
stats_updater_thread = Thread(target=cache_stats)
stats_updater_thread.daemon = True
stats_updater_thread.start()
# Cache the initial stats
print('Loading backend stats...')
generate_stats()
# print(app.url_map)
@@ -177,7 +177,7 @@ def home():
    if not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = opts.running_model
        running_model = redis.get('running_model', str, 'ERROR')
        active_gen_workers = get_active_gen_workers()

        if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: