convert to gunicorn

This commit is contained in:
Cyberes 2023-09-26 13:32:33 -06:00
parent 0eb901cb52
commit e0af2ea9c5
10 changed files with 76 additions and 52 deletions

13
gunicorn.py Normal file
View File

@ -0,0 +1,13 @@
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import server
def on_starting(s):
server.pre_fork(s)
print('Startup complete!')

View File

@ -6,6 +6,7 @@ import llm_server
from llm_server import opts from llm_server import opts
from llm_server.database.conn import db_pool from llm_server.database.conn import db_pool
from llm_server.llm.vllm import tokenize from llm_server.llm.vllm import tokenize
from llm_server.routes.cache import redis
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
@ -33,6 +34,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if token: if token:
increment_token_uses(token) increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
timestamp = int(time.time()) timestamp = int(time.time())
conn = db_pool.connection() conn = db_pool.connection()
cursor = conn.cursor() cursor = conn.cursor()
@ -42,7 +45,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", """,
(ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally: finally:
cursor.close() cursor.close()

View File

@ -10,6 +10,7 @@ import requests
import llm_server import llm_server
from llm_server import opts from llm_server import opts
from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed # TODO: make the VLMM backend return TPS and time elapsed
@ -49,13 +50,14 @@ def transform_to_text(json_request, api_response):
prompt_tokens = len(llm_server.llm.get_token_count(prompt)) prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text)) completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', str, 'ERROR')
# https://platform.openai.com/docs/api-reference/making-requests?lang=python # https://platform.openai.com/docs/api-reference/making-requests?lang=python
return { return {
"id": str(uuid4()), "id": str(uuid4()),
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": opts.running_model, "model": running_model,
"usage": { "usage": {
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens, "completion_tokens": completion_tokens,

View File

@ -2,7 +2,7 @@
# TODO: rewrite the config system so I don't have to add every single config default here # TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'none' running_model = 'ERROR'
concurrent_gens = 3 concurrent_gens = 3
mode = 'oobabooga' mode = 'oobabooga'
backend_url = None backend_url = None

View File

@ -6,6 +6,7 @@ import traceback
from flask import Response, jsonify, request from flask import Response, jsonify, request
from . import openai_bp from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
@ -50,7 +51,7 @@ def openai_chat_completions():
response = generator(msg_to_backend) response = generator(msg_to_backend)
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model') model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30) oai_string = generate_oai_string(30)
def generate(): def generate():

View File

@ -4,7 +4,7 @@ import requests
from flask import jsonify from flask import jsonify
from . import openai_bp from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..stats import server_start_time from ..stats import server_start_time
from ... import opts from ... import opts
from ...helpers import jsonify_pretty from ...helpers import jsonify_pretty
@ -22,6 +22,7 @@ def openai_list_models():
'type': error.__class__.__name__ 'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us }), 500 # return 500 so Cloudflare doesn't intercept us
else: else:
running_model = redis.get('running_model', str, 'ERROR')
oai = fetch_openai_models() oai = fetch_openai_models()
r = [] r = []
if opts.openai_expose_our_model: if opts.openai_expose_our_model:
@ -29,13 +30,13 @@ def openai_list_models():
"object": "list", "object": "list",
"data": [ "data": [
{ {
"id": opts.running_model, "id": running_model,
"object": "model", "object": "model",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name, "owned_by": opts.llm_middleware_name,
"permission": [ "permission": [
{ {
"id": opts.running_model, "id": running_model,
"object": "model_permission", "object": "model_permission",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"allow_create_engine": False, "allow_create_engine": False,

View File

@ -14,6 +14,7 @@ from flask import jsonify
import llm_server import llm_server
from llm_server import opts from llm_server import opts
from llm_server.database.database import log_prompt from llm_server.database.database import log_prompt
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
@ -157,11 +158,13 @@ def build_openai_response(prompt, response, model=None):
# TODO: async/await # TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt) prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response) response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
return jsonify({ return jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}", "id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": opts.running_model if opts.openai_expose_our_model else model, "model": running_model if opts.openai_expose_our_model else model,
"choices": [{ "choices": [{
"index": 0, "index": 0,
"message": { "message": {

View File

@ -46,7 +46,7 @@ def generate_stats(regen: bool = False):
online = False online = False
else: else:
online = True online = True
opts.running_model = model_name redis.set('running_model', model_name)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change # t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0: # if len(t) == 0:

View File

@ -25,35 +25,36 @@ class MainBackgroundThread(Thread):
def run(self): def run(self):
while True: while True:
# TODO: unify this
if opts.mode == 'oobabooga': if opts.mode == 'oobabooga':
model, err = get_running_model() running_model, err = get_running_model()
if err: if err:
print(err) print(err)
redis.set('backend_online', 0) redis.set('backend_online', 0)
else: else:
opts.running_model = model redis.set('running_model', running_model)
redis.set('backend_online', 1) redis.set('backend_online', 1)
elif opts.mode == 'vllm': elif opts.mode == 'vllm':
model, err = get_running_model() running_model, err = get_running_model()
if err: if err:
print(err) print(err)
redis.set('backend_online', 0) redis.set('backend_online', 0)
else: else:
opts.running_model = model redis.set('running_model', running_model)
redis.set('backend_online', 1) redis.set('backend_online', 1)
else: else:
raise Exception raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now. # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model) # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}') # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens) redis.set('average_output_tokens', average_output_tokens)

View File

@ -15,6 +15,8 @@ from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded) # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint # TODO: support turbo-instruct on openai endpoint
@ -22,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
# TODO: validate system tokens before excluding them # TODO: validate system tokens before excluding them
# TODO: make sure prompts are logged even when the user cancels generation # TODO: make sure prompts are logged even when the user cancels generation
# TODO: add some sort of loadbalancer to send requests to a group of backends # TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: use the current estimated wait time for ratelimit headers on openai
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead # TODO: unify logging thread in a function and use async/await instead
@ -43,13 +46,18 @@ from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread, cache_stats from llm_server.threads import MainBackgroundThread, cache_stats
script_path = os.path.dirname(os.path.realpath(__file__)) script_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
cache.init_app(app)
cache.clear() # clear redis cache
config_path_environ = os.getenv("CONFIG_PATH") config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ: if config_path_environ:
config_path = config_path_environ config_path = config_path_environ
@ -73,9 +81,6 @@ if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode']) print('Unknown mode:', config['mode'])
sys.exit(1) sys.exit(1)
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
# TODO: this is a MESS # TODO: this is a MESS
opts.mode = config['mode'] opts.mode = config['mode']
opts.auth_required = config['auth_required'] opts.auth_required = config['auth_required']
@ -108,23 +113,12 @@ if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.') print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
sys.exit(1) sys.exit(1)
if config['http_host']:
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
print('Set host to', redis.get('http_host', str))
opts.verify_ssl = config['verify_ssl'] opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl: if not opts.verify_ssl:
import urllib3 import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
redis.set('backend_mode', opts.mode)
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
if config['average_generation_time_mode'] not in ['database', 'minute']: if config['average_generation_time_mode'] not in ['database', 'minute']:
print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode']) print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
sys.exit(1) sys.exit(1)
@ -138,30 +132,36 @@ elif opts.mode == 'vllm':
else: else:
raise Exception raise Exception
app = Flask(__name__)
cache.init_app(app)
cache.clear() # clear redis cache
# Start background processes def pre_fork(server):
start_workers(opts.concurrent_gens) flushed_keys = redis.flush()
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time) print('Flushed', len(flushed_keys), 'keys from Redis.')
process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
# Cache the initial stats redis.set('backend_mode', opts.mode)
print('Loading backend stats...') if config['http_host']:
generate_stats() http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
init_socketio(app) if config['load_num_prompts']:
app.register_blueprint(bp, url_prefix='/api/v1/') redis.set('proompts', get_number_of_rows('prompts'))
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
# This needs to be started after Flask is initalized # Start background processes
stats_updater_thread = Thread(target=cache_stats) start_workers(opts.concurrent_gens)
stats_updater_thread.daemon = True process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
stats_updater_thread.start() process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
# This needs to be started after Flask is initalized
stats_updater_thread = Thread(target=cache_stats)
stats_updater_thread.daemon = True
stats_updater_thread.start()
# Cache the initial stats
print('Loading backend stats...')
generate_stats()
# print(app.url_map) # print(app.url_map)
@ -177,7 +177,7 @@ def home():
if not stats['online']: if not stats['online']:
running_model = estimated_wait_sec = 'offline' running_model = estimated_wait_sec = 'offline'
else: else:
running_model = opts.running_model running_model = redis.get('running_model', str, 'ERROR')
active_gen_workers = get_active_gen_workers() active_gen_workers = get_active_gen_workers()
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: