get streaming working, remove /v2/

Cyberes 2023-10-01 00:20:00 -06:00
parent b10d22ca0d
commit 25ec56a5ef
12 changed files with 129 additions and 72 deletions

View File

@@ -10,7 +10,7 @@ from llm_server.config.load import load_config, parse_backends
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.routes.queue import priority_queue
from llm_server.routes.v2.generate_stats import generate_stats
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))

View File

@@ -47,9 +47,9 @@ def get_model_choices(regen: bool = False):
estimated_wait_sec = f"{estimated_wait_sec} seconds"
model_choices[model] = {
'client_api': f'https://{base_client_api}/v2/{model}',
'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled',
'client_api': f'https://{base_client_api}/{model}',
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
'backend_count': len(b),
'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue,
@@ -73,9 +73,9 @@ def get_model_choices(regen: bool = False):
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
default_backend_dict = {
'client_api': f'https://{base_client_api}/v2',
'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled',
'client_api': f'https://{base_client_api}/v1',
'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled',
'estimated_wait': default_estimated_wait_sec,
'queued': default_proompters_in_queue,
'processing': default_active_gen_workers,
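
For illustration, a minimal sketch of the endpoint layout this hunk produces; the hostname and model name are placeholders, and base_client_api is assumed to already carry the /api prefix the server keeps in Redis:

# Placeholder values for illustration only.
base_client_api = 'proxy.example.com/api'  # assumption: already includes the /api prefix
model = 'some-model-13b'                   # placeholder model name

# Per-model endpoints after this change (the /v2/ segment is gone):
client_api = f'https://{base_client_api}/{model}'
ws_client_api = f'wss://{base_client_api}/{model}/v1/stream'
openai_client_api = f'https://{base_client_api}/openai/{model}'

# Default-backend endpoints:
default_client_api = f'https://{base_client_api}/v1'
default_ws_client_api = f'wss://{base_client_api}/v1/stream'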

View File

@@ -6,6 +6,7 @@ from llm_server.cluster.cluster_config import cluster_config
def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
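
As a usage note, the added guard lets callers pass an empty or None prompt without tripping the tokenizer; a small sketch, with the backend URL as a placeholder assumption:

prompt_tokens = tokenize('Why is the sky blue?', 'http://10.0.0.50:7000')  # placeholder backend URL
empty_tokens = tokenize('', 'http://10.0.0.50:7000')                       # returns 0 immediately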

View File

@@ -1,18 +1,19 @@
from flask import Blueprint, jsonify
from flask import Blueprint
from llm_server.custom_redis import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from ..request_handler import before_request
from ..server_error import handle_server_error
old_v1_bp = Blueprint('v1', __name__)
bp = Blueprint('v1', __name__)
@old_v1_bp.route('/', defaults={'path': ''}, methods=['GET', 'POST'])
@old_v1_bp.route('/<path:path>', methods=['GET', 'POST'])
def fallback(path):
base_client_api = redis.get('base_client_api', dtype=str)
error_msg = f'The /v1/ endpoint has been deprecated. Please visit {base_client_api} for more information.\nAlso, you must enable "Relaxed API URLS" in settings.'
response_msg = format_sillytavern_err(error_msg, error_type='API')
return jsonify({
'results': [{'text': response_msg}],
'result': f'Wrong API path, visit {base_client_api} for more info.'
}), 200 # return 200 so we don't trigger an error message in the client's ST
@bp.before_request
def before_bp_request():
return before_request()
@bp.errorhandler(500)
def handle_error(e):
return handle_server_error(e)
from . import generate, info, proxy, generate_stream
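
With the old /v1 fallback removed, this module now only exposes bp; server.py mounts it under /api/ (see the hunk further down). A minimal registration sketch, assuming a bare Flask app:

from flask import Flask
from llm_server.routes.v1 import bp  # the blueprint defined above

app = Flask(__name__)
app.register_blueprint(bp, url_prefix='/api/')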

View File

@@ -1,26 +1,39 @@
import json
import time
import traceback
from typing import Union
from flask import request
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...sock import sock
# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
# Stacking the @sock.route() decorator creates a TypeError on the /v1/stream endpoint.
# We solve this by splitting the routes.
@sock.route('/api/v1/stream')
def stream(ws):
@bp.route('/stream')
def stream():
return 'This is a websocket endpoint.', 400
@sock.route('/stream', bp=bp)
def stream_without_model(ws):
do_stream(ws, model_name=None)
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
do_stream(ws, model_name)
def do_stream(ws, model_name):
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
@@ -32,23 +45,33 @@ def stream(ws):
'message_num': 1
}))
ws.close()
log_in_bg(quitting_err_msg, is_error=True)
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.cluster_backend_info,
response_tokens=tokenize(generated_text, handler.backend_url),
is_error=True
)
if not opts.enable_streaming:
return 'Streaming is disabled', 401
return 'Streaming is disabled', 500
cluster_backend = None
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
ws.close()
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
@@ -57,9 +80,10 @@ def stream(ws):
auth_failure = require_api_key(request_json_body)
if auth_failure:
ws.close()
return auth_failure
handler = OobaRequestHandler(request, request_json_body)
handler = OobaRequestHandler(request, model_name, request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
@@ -84,15 +108,14 @@ def stream(ws):
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
cluster_backend = get_a_cluster_backend()
response = generator(llm_request, cluster_backend)
response = generator(llm_request, handler.backend_url)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
@@ -134,10 +157,25 @@ def stream(ws):
# The client has closed the stream.
if request:
request.close()
try:
ws.close()
except:
pass
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url)
)
return
message_num += 1
@@ -149,7 +187,19 @@ def stream(ws):
end_time = time.time()
elapsed_time = end_time - start_time
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url),
is_error=not response
)
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
@@ -161,12 +211,24 @@ def stream(ws):
if request:
request.close()
ws.close()
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url),
is_error=True
)
return
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers()
decr_active_workers(handler.selected_model, handler.backend_url)
try:
ws.send(json.dumps({
'event': 'stream_end',
@@ -176,5 +238,19 @@ def stream(ws):
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url)
)
try:
ws.close() # this is important if we encountered an error and exited early.
except:
pass
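
For context, the handler expects each websocket message to be a JSON body containing at least a 'prompt', and it replies with 'text_stream' events followed by a final 'stream_end'. A rough client sketch, assuming the websocket-client package; the URL, the sampling parameters, and the 'text' key of each event are assumptions here:

import json
from websocket import create_connection  # pip install websocket-client

# Placeholder host and model; the path mirrors the '/<model_name>/v1/stream' route above.
ws = create_connection('wss://proxy.example.com/api/some-model-13b/v1/stream')
ws.send(json.dumps({'prompt': 'Why is the sky blue?'}))  # other generation parameters omitted
while True:
    event = json.loads(ws.recv())
    if event.get('event') == 'text_stream':
        print(event.get('text', ''), end='', flush=True)  # 'text' field is an assumption
    elif event.get('event') == 'stream_end':
        break
ws.close()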

View File

@@ -1,19 +0,0 @@
from flask import Blueprint
from ..request_handler import before_request
from ..server_error import handle_server_error
bp = Blueprint('v2', __name__)
@bp.before_request
def before_bp_request():
return before_request()
@bp.errorhandler(500)
def handle_error(e):
return handle_server_error(e)
from . import generate, info, proxy, generate_stream

View File

@@ -4,7 +4,7 @@ from threading import Thread
from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v2.generate_stats import generate_stats
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers

View File

@@ -21,8 +21,7 @@ from llm_server.database.create import create_db
from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import old_v1_bp
from llm_server.routes.v2 import bp
from llm_server.routes.v1 import bp
from llm_server.sock import init_socketio
# TODO: per-backend workers
@@ -66,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.custom_redis import flask_cache
from llm_server.llm import redis
from llm_server.routes.v2.generate_stats import generate_stats
from llm_server.routes.v1.generate_stats import generate_stats
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v2/')
app.register_blueprint(old_v1_bp, url_prefix='/api/v1/')
app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()
@@ -133,8 +131,8 @@ def home():
default_active_gen_workers=default_backend_info['processing'],
default_proompters_in_queue=default_backend_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}/v2',
ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
client_api=f'https://{base_client_api}/v1',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],