From 25ec56a5efb22846246c8d90a40db2cfcf546bce Mon Sep 17 00:00:00 2001 From: Cyberes Date: Sun, 1 Oct 2023 00:20:00 -0600 Subject: [PATCH] get streaming working, remove /v2/ --- daemon.py | 2 +- llm_server/cluster/model_choices.py | 12 +- llm_server/llm/vllm/tokenize.py | 1 + llm_server/routes/v1/__init__.py | 29 ++-- llm_server/routes/{v2 => v1}/generate.py | 0 .../routes/{v2 => v1}/generate_stats.py | 0 .../routes/{v2 => v1}/generate_stream.py | 124 ++++++++++++++---- llm_server/routes/{v2 => v1}/info.py | 0 llm_server/routes/{v2 => v1}/proxy.py | 0 llm_server/routes/v2/__init__.py | 19 --- llm_server/workers/threader.py | 2 +- server.py | 12 +- 12 files changed, 129 insertions(+), 72 deletions(-) rename llm_server/routes/{v2 => v1}/generate.py (100%) rename llm_server/routes/{v2 => v1}/generate_stats.py (100%) rename llm_server/routes/{v2 => v1}/generate_stream.py (56%) rename llm_server/routes/{v2 => v1}/info.py (100%) rename llm_server/routes/{v2 => v1}/proxy.py (100%) delete mode 100644 llm_server/routes/v2/__init__.py diff --git a/daemon.py b/daemon.py index aac2657..0fa3601 100644 --- a/daemon.py +++ b/daemon.py @@ -10,7 +10,7 @@ from llm_server.config.load import load_config, parse_backends from llm_server.custom_redis import redis from llm_server.database.create import create_db from llm_server.routes.queue import priority_queue -from llm_server.routes.v2.generate_stats import generate_stats +from llm_server.routes.v1.generate_stats import generate_stats from llm_server.workers.threader import start_background script_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/llm_server/cluster/model_choices.py b/llm_server/cluster/model_choices.py index ec78e2f..f8383fe 100644 --- a/llm_server/cluster/model_choices.py +++ b/llm_server/cluster/model_choices.py @@ -47,9 +47,9 @@ def get_model_choices(regen: bool = False): estimated_wait_sec = f"{estimated_wait_sec} seconds" model_choices[model] = { - 'client_api': f'https://{base_client_api}/v2/{model}', - 'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None, - 'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled', + 'client_api': f'https://{base_client_api}/{model}', + 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None, + 'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled', 'backend_count': len(b), 'estimated_wait': estimated_wait_sec, 'queued': proompters_in_queue, @@ -73,9 +73,9 @@ def get_model_choices(regen: bool = False): default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers) default_backend_dict = { - 'client_api': f'https://{base_client_api}/v2', - 'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None, - 'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled', + 'client_api': f'https://{base_client_api}/v1', + 'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, + 'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled', 'estimated_wait': default_estimated_wait_sec, 'queued': default_proompters_in_queue, 'processing': default_active_gen_workers, diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index 747a8b8..5cad1a4 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -6,6 +6,7 @@ from llm_server.cluster.cluster_config import cluster_config def tokenize(prompt: str, backend_url: str) -> int: + assert backend_url if not prompt: # The tokenizers have issues when the prompt is None. return 0 diff --git a/llm_server/routes/v1/__init__.py b/llm_server/routes/v1/__init__.py index a52cb2e..123683d 100644 --- a/llm_server/routes/v1/__init__.py +++ b/llm_server/routes/v1/__init__.py @@ -1,18 +1,19 @@ -from flask import Blueprint, jsonify +from flask import Blueprint -from llm_server.custom_redis import redis -from llm_server.routes.helpers.client import format_sillytavern_err +from ..request_handler import before_request +from ..server_error import handle_server_error -old_v1_bp = Blueprint('v1', __name__) +bp = Blueprint('v1', __name__) -@old_v1_bp.route('/', defaults={'path': ''}, methods=['GET', 'POST']) -@old_v1_bp.route('/', methods=['GET', 'POST']) -def fallback(path): - base_client_api = redis.get('base_client_api', dtype=str) - error_msg = f'The /v1/ endpoint has been depreciated. Please visit {base_client_api} for more information.\nAlso, you must enable "Relaxed API URLS" in settings.' - response_msg = format_sillytavern_err(error_msg, error_type='API') - return jsonify({ - 'results': [{'text': response_msg}], - 'result': f'Wrong API path, visit {base_client_api} for more info.' - }), 200 # return 200 so we don't trigger an error message in the client's ST +@bp.before_request +def before_bp_request(): + return before_request() + + +@bp.errorhandler(500) +def handle_error(e): + return handle_server_error(e) + + +from . import generate, info, proxy, generate_stream diff --git a/llm_server/routes/v2/generate.py b/llm_server/routes/v1/generate.py similarity index 100% rename from llm_server/routes/v2/generate.py rename to llm_server/routes/v1/generate.py diff --git a/llm_server/routes/v2/generate_stats.py b/llm_server/routes/v1/generate_stats.py similarity index 100% rename from llm_server/routes/v2/generate_stats.py rename to llm_server/routes/v1/generate_stats.py diff --git a/llm_server/routes/v2/generate_stream.py b/llm_server/routes/v1/generate_stream.py similarity index 56% rename from llm_server/routes/v2/generate_stream.py rename to llm_server/routes/v1/generate_stream.py index e3aeeb0..c0a0927 100644 --- a/llm_server/routes/v2/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,26 +1,39 @@ import json import time import traceback -from typing import Union from flask import request +from . import bp from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... import opts -from ...cluster.backend import get_a_cluster_backend from ...database.database import log_prompt from ...llm.generator import generator from ...llm.vllm import tokenize from ...sock import sock -# TODO: have workers process streaming requests -# TODO: make sure to log the token as well (seems to be missing in the DB right now) +# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint. +# We solve this by splitting the routes -@sock.route('/api/v1/stream') -def stream(ws): +@bp.route('/stream') +def stream(): + return 'This is a websocket endpoint.', 400 + + +@sock.route('/stream', bp=bp) +def stream_without_model(ws): + do_stream(ws, model_name=None) + + +@sock.route('//v1/stream', bp=bp) +def stream_with_model(ws, model_name=None): + do_stream(ws, model_name) + + +def do_stream(ws, model_name): def send_err_and_quit(quitting_err_msg): ws.send(json.dumps({ 'event': 'text_stream', @@ -32,23 +45,33 @@ def stream(ws): 'message_num': 1 })) ws.close() - log_in_bg(quitting_err_msg, is_error=True) - - def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None): - generated_tokens = tokenize(generated_text_bg) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error) + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=quitting_err_msg, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.cluster_backend_info, + response_tokens=tokenize(generated_text, handler.backend_url), + is_error=True + ) if not opts.enable_streaming: - return 'Streaming is disabled', 401 + return 'Streaming is disabled', 500 - cluster_backend = None r_headers = dict(request.headers) r_url = request.url message_num = 0 + while ws.connected: message = ws.receive() request_valid_json, request_json_body = validate_json(message) + if not request_valid_json or not request_json_body.get('prompt'): + ws.close() return 'Invalid JSON', 400 else: if opts.mode != 'vllm': @@ -57,9 +80,10 @@ def stream(ws): auth_failure = require_api_key(request_json_body) if auth_failure: + ws.close() return auth_failure - handler = OobaRequestHandler(request, request_json_body) + handler = OobaRequestHandler(request, model_name, request_json_body) generated_text = '' input_prompt = request_json_body['prompt'] response_status_code = 0 @@ -84,15 +108,14 @@ def stream(ws): } # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority) + event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) if not event: r, _ = handler.handle_ratelimited() err_msg = r.json['results'][0]['text'] send_err_and_quit(err_msg) return try: - cluster_backend = get_a_cluster_backend() - response = generator(llm_request, cluster_backend) + response = generator(llm_request, handler.backend_url) if not response: error_msg = 'Failed to reach backend while streaming.' print('Streaming failed:', error_msg) @@ -134,10 +157,25 @@ def stream(ws): # The has client closed the stream. if request: request.close() - ws.close() + try: + ws.close() + except: + pass end_time = time.time() elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text)) + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=tokenize(generated_text, handler.backend_url) + ) + return message_num += 1 @@ -149,7 +187,19 @@ def stream(ws): end_time = time.time() elapsed_time = end_time - start_time - log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code) + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=tokenize(generated_text, handler.backend_url), + is_error=not response + ) except: traceback.print_exc() generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] @@ -161,12 +211,24 @@ def stream(ws): if request: request.close() ws.close() - log_in_bg(generated_text, is_error=True, status_code=response_status_code) + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=None, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=tokenize(generated_text, handler.backend_url), + is_error=True + ) return finally: # The worker incremented it, we'll decrement it. decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers() + decr_active_workers(handler.selected_model, handler.backend_url) try: ws.send(json.dumps({ 'event': 'stream_end', @@ -176,5 +238,19 @@ def stream(ws): # The client closed the stream. end_time = time.time() elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text)) - ws.close() # this is important if we encountered and error and exited early. + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=tokenize(generated_text, handler.backend_url) + ) + try: + ws.close() # this is important if we encountered and error and exited early. + except: + pass diff --git a/llm_server/routes/v2/info.py b/llm_server/routes/v1/info.py similarity index 100% rename from llm_server/routes/v2/info.py rename to llm_server/routes/v1/info.py diff --git a/llm_server/routes/v2/proxy.py b/llm_server/routes/v1/proxy.py similarity index 100% rename from llm_server/routes/v2/proxy.py rename to llm_server/routes/v1/proxy.py diff --git a/llm_server/routes/v2/__init__.py b/llm_server/routes/v2/__init__.py deleted file mode 100644 index 1860473..0000000 --- a/llm_server/routes/v2/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from flask import Blueprint - -from ..request_handler import before_request -from ..server_error import handle_server_error - -bp = Blueprint('v2', __name__) - - -@bp.before_request -def before_bp_request(): - return before_request() - - -@bp.errorhandler(500) -def handle_error(e): - return handle_server_error(e) - - -from . import generate, info, proxy, generate_stream diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py index 89a6770..0c82559 100644 --- a/llm_server/workers/threader.py +++ b/llm_server/workers/threader.py @@ -4,7 +4,7 @@ from threading import Thread from llm_server import opts from llm_server.cluster.stores import redis_running_models from llm_server.cluster.worker import cluster_worker -from llm_server.routes.v2.generate_stats import generate_stats +from llm_server.routes.v1.generate_stats import generate_stats from llm_server.workers.inferencer import start_workers from llm_server.workers.mainer import main_background_thread from llm_server.workers.moderator import start_moderation_workers diff --git a/server.py b/server.py index 7cb5a88..0eba490 100644 --- a/server.py +++ b/server.py @@ -21,8 +21,7 @@ from llm_server.database.create import create_db from llm_server.pre_fork import server_startup from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error -from llm_server.routes.v1 import old_v1_bp -from llm_server.routes.v2 import bp +from llm_server.routes.v1 import bp from llm_server.sock import init_socketio # TODO: per-backend workers @@ -66,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info from llm_server.custom_redis import flask_cache from llm_server.llm import redis -from llm_server.routes.v2.generate_stats import generate_stats +from llm_server.routes.v1.generate_stats import generate_stats app = Flask(__name__) init_socketio(app) -app.register_blueprint(bp, url_prefix='/api/v2/') -app.register_blueprint(old_v1_bp, url_prefix='/api/v1/') +app.register_blueprint(bp, url_prefix='/api/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') flask_cache.init_app(app) flask_cache.clear() @@ -133,8 +131,8 @@ def home(): default_active_gen_workers=default_backend_info['processing'], default_proompters_in_queue=default_backend_info['queued'], current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model, - client_api=f'https://{base_client_api}/v2', - ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled', + client_api=f'https://{base_client_api}/v1', + ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled', default_estimated_wait=default_estimated_wait_sec, mode_name=mode_ui_names[opts.mode][0], api_input_textbox=mode_ui_names[opts.mode][1],