convert to gunicorn

parent 0eb901cb52
commit e0af2ea9c5
@@ -0,0 +1,13 @@
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import server
+
+
+def on_starting(s):
+    server.pre_fork(s)
+    print('Startup complete!')
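Note: the new file above is evidently the gunicorn configuration module: gevent is monkey-patched as early as possible, and gunicorn's `on_starting` server hook hands the arbiter to `server.pre_fork()` before any workers are forked. Such a module is normally passed to gunicorn with `-c`. The sketch below shows how it might be extended with settings; the file name, bind address, worker count, and worker class are assumptions for illustration, not part of this commit:

# Hypothetical additions to the config module above -- all values are guesses.
# Typical launch (file name assumed): gunicorn -c gunicorn_conf.py server:app

bind = '0.0.0.0:5000'      # address/port gunicorn listens on
workers = 4                # number of forked worker processes
worker_class = 'gevent'    # pairs with the gevent monkey-patching at the top of the file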
@@ -6,6 +6,7 @@ import llm_server
 from llm_server import opts
 from llm_server.database.conn import db_pool
 from llm_server.llm.vllm import tokenize
+from llm_server.routes.cache import redis


 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
@@ -33,6 +34,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     if token:
         increment_token_uses(token)

+    running_model = redis.get('running_model', str, 'ERROR')
+
     timestamp = int(time.time())
     conn = db_pool.connection()
     cursor = conn.cursor()
@@ -42,7 +45,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
             (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
             """,
-            (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+            (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     finally:
         cursor.close()

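Note: throughout this commit, the `redis` object imported from `llm_server.routes.cache` is called as `redis.get(key, dtype, default)`, `redis.set(key, value)`, and `redis.flush()`. That three-argument `get` is not plain redis-py, so the cache module presumably wraps the client. A minimal sketch of such a wrapper, written only to clarify the calls used in these hunks (the class name, key prefix, and details are assumptions, not the project's actual code):

# Hypothetical wrapper sketch -- the project's real implementation may differ.
import redis as redis_lib


class RedisWrapper:
    def __init__(self, prefix: str = 'local_llm', **kwargs):
        self.prefix = prefix
        self.redis = redis_lib.Redis(**kwargs)

    def _key(self, key: str) -> str:
        return f'{self.prefix}:{key}'

    def set(self, key, value):
        return self.redis.set(self._key(key), value)

    def get(self, key, dtype=None, default=None):
        # Return `default` when the key is missing; otherwise cast to `dtype`.
        value = self.redis.get(self._key(key))
        if value is None:
            return default
        return dtype(value.decode()) if dtype else value

    def flush(self):
        # Delete every key under our prefix and return the deleted keys.
        keys = self.redis.keys(f'{self.prefix}:*')
        if keys:
            self.redis.delete(*keys)
        return keys


redis = RedisWrapper()  # shared module-level instance, mirroring llm_server.routes.cache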
@@ -10,6 +10,7 @@ import requests

 import llm_server
 from llm_server import opts
+from llm_server.routes.cache import redis


 # TODO: make the VLMM backend return TPS and time elapsed
@@ -49,13 +50,14 @@ def transform_to_text(json_request, api_response):

     prompt_tokens = len(llm_server.llm.get_token_count(prompt))
     completion_tokens = len(llm_server.llm.get_token_count(text))
+    running_model = redis.get('running_model', str, 'ERROR')

     # https://platform.openai.com/docs/api-reference/making-requests?lang=python
     return {
         "id": str(uuid4()),
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model,
+        "model": running_model,
         "usage": {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
@@ -2,7 +2,7 @@

 # TODO: rewrite the config system so I don't have to add every single config default here

-running_model = 'none'
+running_model = 'ERROR'
 concurrent_gens = 3
 mode = 'oobabooga'
 backend_url = None
@@ -6,6 +6,7 @@ import traceback
 from flask import Response, jsonify, request

 from . import openai_bp
+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
@@ -50,7 +51,7 @@ def openai_chat_completions():
         response = generator(msg_to_backend)
         r_headers = dict(request.headers)
         r_url = request.url
-        model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
+        model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
         oai_string = generate_oai_string(30)

         def generate():
@@ -4,7 +4,7 @@ import requests
 from flask import jsonify

 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, cache
+from ..cache import ONE_MONTH_SECONDS, cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...helpers import jsonify_pretty
@@ -22,6 +22,7 @@ def openai_list_models():
             'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
+        running_model = redis.get('running_model', str, 'ERROR')
         oai = fetch_openai_models()
         r = []
         if opts.openai_expose_our_model:
@@ -29,13 +30,13 @@ def openai_list_models():
                 "object": "list",
                 "data": [
                     {
-                        "id": opts.running_model,
+                        "id": running_model,
                         "object": "model",
                         "created": int(server_start_time.timestamp()),
                         "owned_by": opts.llm_middleware_name,
                         "permission": [
                             {
-                                "id": opts.running_model,
+                                "id": running_model,
                                 "object": "model_permission",
                                 "created": int(server_start_time.timestamp()),
                                 "allow_create_engine": False,
@@ -14,6 +14,7 @@ from flask import jsonify
 import llm_server
 from llm_server import opts
 from llm_server.database.database import log_prompt
+from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler

@@ -157,11 +158,13 @@ def build_openai_response(prompt, response, model=None):
     # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)
+    running_model = redis.get('running_model', str, 'ERROR')
+
     return jsonify({
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model if opts.openai_expose_our_model else model,
+        "model": running_model if opts.openai_expose_our_model else model,
         "choices": [{
             "index": 0,
             "message": {
@@ -46,7 +46,7 @@ def generate_stats(regen: bool = False):
         online = False
     else:
         online = True
-        opts.running_model = model_name
+        redis.set('running_model', model_name)

     # t = elapsed_times.copy()  # copy since we do multiple operations and don't want it to change
     # if len(t) == 0:
@@ -25,35 +25,36 @@ class MainBackgroundThread(Thread):

     def run(self):
         while True:
+            # TODO: unify this
             if opts.mode == 'oobabooga':
-                model, err = get_running_model()
+                running_model, err = get_running_model()
                 if err:
                     print(err)
                     redis.set('backend_online', 0)
                 else:
-                    opts.running_model = model
+                    redis.set('running_model', running_model)
                     redis.set('backend_online', 1)
             elif opts.mode == 'vllm':
-                model, err = get_running_model()
+                running_model, err = get_running_model()
                 if err:
                     print(err)
                     redis.set('backend_online', 0)
                 else:
-                    opts.running_model = model
+                    redis.set('running_model', running_model)
                     redis.set('backend_online', 1)
             else:
                 raise Exception

             # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
             # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
-            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
             if average_generation_elapsed_sec:  # returns None on exception
                 redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

             # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
             # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

-            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
             if average_generation_elapsed_sec:
                 redis.set('average_output_tokens', average_output_tokens)

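Note: the recurring pattern in these hunks is moving the running model name out of the module-level `opts.running_model` and into Redis. Under gunicorn each worker is a separate forked process, so a module global set by the background thread in one process is never seen by the others; a shared Redis key is visible to all of them, which is presumably the motivation. Illustrative only (the model name is a placeholder, not from this commit):

from llm_server.routes.cache import redis  # the wrapper used throughout this commit

redis.set('running_model', 'llama-2-13b')                  # e.g. written by the background thread
running_model = redis.get('running_model', str, 'ERROR')   # read by whichever worker serves a request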

server.py (74 changed lines)

@@ -15,6 +15,8 @@ from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
+from llm_server.routes.v1 import bp
+from llm_server.stream import init_socketio

 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
@@ -22,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: validate system tokens before excluding them
 # TODO: make sure prompts are logged even when the user cancels generation
 # TODO: add some sort of loadbalancer to send requests to a group of backends
+# TODO: use the current estimated wait time for ratelimit headers on openai

 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead
@@ -43,13 +46,18 @@ from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
-from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.stream import init_socketio
 from llm_server.threads import MainBackgroundThread, cache_stats

 script_path = os.path.dirname(os.path.realpath(__file__))

+app = Flask(__name__)
+init_socketio(app)
+app.register_blueprint(bp, url_prefix='/api/v1/')
+app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
+cache.init_app(app)
+cache.clear()  # clear redis cache
+
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
@@ -73,9 +81,6 @@ if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)

-flushed_keys = redis.flush()
-print('Flushed', len(flushed_keys), 'keys from Redis.')
-
 # TODO: this is a MESS
 opts.mode = config['mode']
 opts.auth_required = config['auth_required']
@@ -108,23 +113,12 @@ if opts.openai_expose_our_model and not opts.openai_api_key:
     print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
     sys.exit(1)

-if config['http_host']:
-    http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
-    redis.set('http_host', http_host)
-    redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
-    print('Set host to', redis.get('http_host', str))
-
 opts.verify_ssl = config['verify_ssl']
 if not opts.verify_ssl:
     import urllib3

     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

-redis.set('backend_mode', opts.mode)
-
-if config['load_num_prompts']:
-    redis.set('proompts', get_number_of_rows('prompts'))
-
 if config['average_generation_time_mode'] not in ['database', 'minute']:
     print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
     sys.exit(1)
@@ -138,30 +132,36 @@ elif opts.mode == 'vllm':
 else:
     raise Exception

-app = Flask(__name__)
-cache.init_app(app)
-cache.clear()  # clear redis cache
-
-# Start background processes
-start_workers(opts.concurrent_gens)
-process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
-process_avg_gen_time_background_thread.daemon = True
-process_avg_gen_time_background_thread.start()
-MainBackgroundThread().start()
-SemaphoreCheckerThread().start()
-
-# Cache the initial stats
-print('Loading backend stats...')
-generate_stats()
-
-init_socketio(app)
-app.register_blueprint(bp, url_prefix='/api/v1/')
-app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
-
-# This needs to be started after Flask is initalized
-stats_updater_thread = Thread(target=cache_stats)
-stats_updater_thread.daemon = True
-stats_updater_thread.start()
-
+
+def pre_fork(server):
+    flushed_keys = redis.flush()
+    print('Flushed', len(flushed_keys), 'keys from Redis.')
+
+    redis.set('backend_mode', opts.mode)
+    if config['http_host']:
+        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        redis.set('http_host', http_host)
+        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
+
+    if config['load_num_prompts']:
+        redis.set('proompts', get_number_of_rows('prompts'))
+
+    # Start background processes
+    start_workers(opts.concurrent_gens)
+    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
+    process_avg_gen_time_background_thread.daemon = True
+    process_avg_gen_time_background_thread.start()
+    MainBackgroundThread().start()
+    SemaphoreCheckerThread().start()
+
+    # This needs to be started after Flask is initalized
+    stats_updater_thread = Thread(target=cache_stats)
+    stats_updater_thread.daemon = True
+    stats_updater_thread.start()
+
+    # Cache the initial stats
+    print('Loading backend stats...')
+    generate_stats()
+

 # print(app.url_map)
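Note: the restructuring above splits startup work between gunicorn's master and its workers. `on_starting(server)` fires once in the master, before any worker is forked, which is why the one-time work (flushing Redis, writing config-derived keys, starting the stats and background threads) now lives in `pre_fork()`, while building the Flask app and registering blueprints moved to module level in server.py, where it runs on import. A hedged sketch of that hook split; the `post_fork` hook is illustration only and is not part of this commit:

import server


def on_starting(s):
    # Master process, before fork: one-time startup via the new pre_fork().
    server.pre_fork(s)
    print('Startup complete!')


def post_fork(s, worker):
    # Hypothetical per-worker hook (not in this commit): runs in each worker
    # immediately after it is forked from the master.
    print(f'Worker {worker.pid} forked.')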
@@ -177,7 +177,7 @@ def home():
     if not stats['online']:
         running_model = estimated_wait_sec = 'offline'
     else:
-        running_model = opts.running_model
+        running_model = redis.get('running_model', str, 'ERROR')

         active_gen_workers = get_active_gen_workers()
         if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: