convert to gunicorn
This commit is contained in:
parent
0eb901cb52
commit
e0af2ea9c5
|
@ -0,0 +1,13 @@
|
|||
try:
|
||||
import gevent.monkey
|
||||
|
||||
gevent.monkey.patch_all()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import server
|
||||
|
||||
|
||||
def on_starting(s):
|
||||
server.pre_fork(s)
|
||||
print('Startup complete!')
|
|
@ -6,6 +6,7 @@ import llm_server
|
|||
from llm_server import opts
|
||||
from llm_server.database.conn import db_pool
|
||||
from llm_server.llm.vllm import tokenize
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
|
||||
|
@ -33,6 +34,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
if token:
|
||||
increment_token_uses(token)
|
||||
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
timestamp = int(time.time())
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
|
@ -42,7 +45,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import requests
|
|||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
# TODO: make the VLMM backend return TPS and time elapsed
|
||||
|
@ -49,13 +50,14 @@ def transform_to_text(json_request, api_response):
|
|||
|
||||
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
|
||||
completion_tokens = len(llm_server.llm.get_token_count(text))
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
|
||||
return {
|
||||
"id": str(uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": opts.running_model,
|
||||
"model": running_model,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
# TODO: rewrite the config system so I don't have to add every single config default here
|
||||
|
||||
running_model = 'none'
|
||||
running_model = 'ERROR'
|
||||
concurrent_gens = 3
|
||||
mode = 'oobabooga'
|
||||
backend_url = None
|
||||
|
|
|
@ -6,6 +6,7 @@ import traceback
|
|||
from flask import Response, jsonify, request
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import redis
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
|
||||
|
@ -50,7 +51,7 @@ def openai_chat_completions():
|
|||
response = generator(msg_to_backend)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
|
|
|
@ -4,7 +4,7 @@ import requests
|
|||
from flask import jsonify
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import ONE_MONTH_SECONDS, cache
|
||||
from ..cache import ONE_MONTH_SECONDS, cache, redis
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...helpers import jsonify_pretty
|
||||
|
@ -22,6 +22,7 @@ def openai_list_models():
|
|||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
else:
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
oai = fetch_openai_models()
|
||||
r = []
|
||||
if opts.openai_expose_our_model:
|
||||
|
@ -29,13 +30,13 @@ def openai_list_models():
|
|||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": opts.running_model,
|
||||
"id": running_model,
|
||||
"object": "model",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"owned_by": opts.llm_middleware_name,
|
||||
"permission": [
|
||||
{
|
||||
"id": opts.running_model,
|
||||
"id": running_model,
|
||||
"object": "model_permission",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"allow_create_engine": False,
|
||||
|
|
|
@ -14,6 +14,7 @@ from flask import jsonify
|
|||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
|
@ -157,11 +158,13 @@ def build_openai_response(prompt, response, model=None):
|
|||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
return jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": opts.running_model if opts.openai_expose_our_model else model,
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
|
|
|
@ -46,7 +46,7 @@ def generate_stats(regen: bool = False):
|
|||
online = False
|
||||
else:
|
||||
online = True
|
||||
opts.running_model = model_name
|
||||
redis.set('running_model', model_name)
|
||||
|
||||
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
|
||||
# if len(t) == 0:
|
||||
|
|
|
@ -25,35 +25,36 @@ class MainBackgroundThread(Thread):
|
|||
|
||||
def run(self):
|
||||
while True:
|
||||
# TODO: unify this
|
||||
if opts.mode == 'oobabooga':
|
||||
model, err = get_running_model()
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
opts.running_model = model
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
elif opts.mode == 'vllm':
|
||||
model, err = get_running_model()
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
opts.running_model = model
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec: # returns None on exception
|
||||
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||
|
||||
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
|
||||
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
|
||||
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec:
|
||||
redis.set('average_output_tokens', average_output_tokens)
|
||||
|
||||
|
|
56
server.py
56
server.py
|
@ -15,6 +15,8 @@ from llm_server.database.database import get_number_of_rows
|
|||
from llm_server.llm import get_token_count
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.stream import init_socketio
|
||||
|
||||
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
|
||||
# TODO: support turbo-instruct on openai endpoint
|
||||
|
@ -22,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
|
|||
# TODO: validate system tokens before excluding them
|
||||
# TODO: make sure prompts are logged even when the user cancels generation
|
||||
# TODO: add some sort of loadbalancer to send requests to a group of backends
|
||||
# TODO: use the current estimated wait time for ratelimit headers on openai
|
||||
|
||||
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
|
||||
# TODO: unify logging thread in a function and use async/await instead
|
||||
|
@ -43,13 +46,18 @@ from llm_server.llm.vllm.info import vllm_info
|
|||
from llm_server.routes.cache import cache, redis
|
||||
from llm_server.routes.queue import start_workers
|
||||
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.stream import init_socketio
|
||||
from llm_server.threads import MainBackgroundThread, cache_stats
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
app = Flask(__name__)
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
|
||||
config_path_environ = os.getenv("CONFIG_PATH")
|
||||
if config_path_environ:
|
||||
config_path = config_path_environ
|
||||
|
@ -73,9 +81,6 @@ if config['mode'] not in ['oobabooga', 'vllm']:
|
|||
print('Unknown mode:', config['mode'])
|
||||
sys.exit(1)
|
||||
|
||||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
# TODO: this is a MESS
|
||||
opts.mode = config['mode']
|
||||
opts.auth_required = config['auth_required']
|
||||
|
@ -108,23 +113,12 @@ if opts.openai_expose_our_model and not opts.openai_api_key:
|
|||
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
|
||||
sys.exit(1)
|
||||
|
||||
if config['http_host']:
|
||||
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
|
||||
redis.set('http_host', http_host)
|
||||
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
|
||||
print('Set host to', redis.get('http_host', str))
|
||||
|
||||
opts.verify_ssl = config['verify_ssl']
|
||||
if not opts.verify_ssl:
|
||||
import urllib3
|
||||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
redis.set('backend_mode', opts.mode)
|
||||
|
||||
if config['load_num_prompts']:
|
||||
redis.set('proompts', get_number_of_rows('prompts'))
|
||||
|
||||
if config['average_generation_time_mode'] not in ['database', 'minute']:
|
||||
print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
|
||||
sys.exit(1)
|
||||
|
@ -138,9 +132,19 @@ elif opts.mode == 'vllm':
|
|||
else:
|
||||
raise Exception
|
||||
|
||||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
|
||||
def pre_fork(server):
|
||||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
redis.set('backend_mode', opts.mode)
|
||||
if config['http_host']:
|
||||
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
|
||||
redis.set('http_host', http_host)
|
||||
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
|
||||
|
||||
if config['load_num_prompts']:
|
||||
redis.set('proompts', get_number_of_rows('prompts'))
|
||||
|
||||
# Start background processes
|
||||
start_workers(opts.concurrent_gens)
|
||||
|
@ -150,19 +154,15 @@ process_avg_gen_time_background_thread.start()
|
|||
MainBackgroundThread().start()
|
||||
SemaphoreCheckerThread().start()
|
||||
|
||||
# Cache the initial stats
|
||||
print('Loading backend stats...')
|
||||
generate_stats()
|
||||
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
|
||||
# This needs to be started after Flask is initalized
|
||||
stats_updater_thread = Thread(target=cache_stats)
|
||||
stats_updater_thread.daemon = True
|
||||
stats_updater_thread.start()
|
||||
|
||||
# Cache the initial stats
|
||||
print('Loading backend stats...')
|
||||
generate_stats()
|
||||
|
||||
|
||||
# print(app.url_map)
|
||||
|
||||
|
@ -177,7 +177,7 @@ def home():
|
|||
if not stats['online']:
|
||||
running_model = estimated_wait_sec = 'offline'
|
||||
else:
|
||||
running_model = opts.running_model
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
|
||||
|
|
Reference in New Issue