get streaming working, remove /v2/
This commit is contained in:
parent
b10d22ca0d
commit
25ec56a5ef
|
@ -10,7 +10,7 @@ from llm_server.config.load import load_config, parse_backends
|
|||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.v2.generate_stats import generate_stats
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.threader import start_background
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
|
|
@ -47,9 +47,9 @@ def get_model_choices(regen: bool = False):
|
|||
estimated_wait_sec = f"{estimated_wait_sec} seconds"
|
||||
|
||||
model_choices[model] = {
|
||||
'client_api': f'https://{base_client_api}/v2/{model}',
|
||||
'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'client_api': f'https://{base_client_api}/{model}',
|
||||
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'backend_count': len(b),
|
||||
'estimated_wait': estimated_wait_sec,
|
||||
'queued': proompters_in_queue,
|
||||
|
@ -73,9 +73,9 @@ def get_model_choices(regen: bool = False):
|
|||
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
|
||||
|
||||
default_backend_dict = {
|
||||
'client_api': f'https://{base_client_api}/v2',
|
||||
'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'client_api': f'https://{base_client_api}/v1',
|
||||
'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'estimated_wait': default_estimated_wait_sec,
|
||||
'queued': default_proompters_in_queue,
|
||||
'processing': default_active_gen_workers,
|
||||
|
|
|
@ -6,6 +6,7 @@ from llm_server.cluster.cluster_config import cluster_config
|
|||
|
||||
|
||||
def tokenize(prompt: str, backend_url: str) -> int:
|
||||
assert backend_url
|
||||
if not prompt:
|
||||
# The tokenizers have issues when the prompt is None.
|
||||
return 0
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
from flask import Blueprint, jsonify
|
||||
from flask import Blueprint
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from ..request_handler import before_request
|
||||
from ..server_error import handle_server_error
|
||||
|
||||
old_v1_bp = Blueprint('v1', __name__)
|
||||
bp = Blueprint('v1', __name__)
|
||||
|
||||
|
||||
@old_v1_bp.route('/', defaults={'path': ''}, methods=['GET', 'POST'])
|
||||
@old_v1_bp.route('/<path:path>', methods=['GET', 'POST'])
|
||||
def fallback(path):
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
error_msg = f'The /v1/ endpoint has been depreciated. Please visit {base_client_api} for more information.\nAlso, you must enable "Relaxed API URLS" in settings.'
|
||||
response_msg = format_sillytavern_err(error_msg, error_type='API')
|
||||
return jsonify({
|
||||
'results': [{'text': response_msg}],
|
||||
'result': f'Wrong API path, visit {base_client_api} for more info.'
|
||||
}), 200 # return 200 so we don't trigger an error message in the client's ST
|
||||
@bp.before_request
|
||||
def before_bp_request():
|
||||
return before_request()
|
||||
|
||||
|
||||
@bp.errorhandler(500)
|
||||
def handle_error(e):
|
||||
return handle_server_error(e)
|
||||
|
||||
|
||||
from . import generate, info, proxy, generate_stream
|
||||
|
|
|
@ -1,26 +1,39 @@
|
|||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
from flask import request
|
||||
|
||||
from . import bp
|
||||
from ..helpers.http import require_api_key, validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.vllm import tokenize
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
# TODO: have workers process streaming requests
|
||||
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
|
||||
# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
|
||||
# We solve this by splitting the routes
|
||||
|
||||
@sock.route('/api/v1/stream')
|
||||
def stream(ws):
|
||||
@bp.route('/stream')
|
||||
def stream():
|
||||
return 'This is a websocket endpoint.', 400
|
||||
|
||||
|
||||
@sock.route('/stream', bp=bp)
|
||||
def stream_without_model(ws):
|
||||
do_stream(ws, model_name=None)
|
||||
|
||||
|
||||
@sock.route('/<model_name>/v1/stream', bp=bp)
|
||||
def stream_with_model(ws, model_name=None):
|
||||
do_stream(ws, model_name)
|
||||
|
||||
|
||||
def do_stream(ws, model_name):
|
||||
def send_err_and_quit(quitting_err_msg):
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
|
@ -32,23 +45,33 @@ def stream(ws):
|
|||
'message_num': 1
|
||||
}))
|
||||
ws.close()
|
||||
log_in_bg(quitting_err_msg, is_error=True)
|
||||
|
||||
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
|
||||
generated_tokens = tokenize(generated_text_bg)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=quitting_err_msg,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
is_error=True
|
||||
)
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 401
|
||||
return 'Streaming is disabled', 500
|
||||
|
||||
cluster_backend = None
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
ws.close()
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
if opts.mode != 'vllm':
|
||||
|
@ -57,9 +80,10 @@ def stream(ws):
|
|||
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
ws.close()
|
||||
return auth_failure
|
||||
|
||||
handler = OobaRequestHandler(request, request_json_body)
|
||||
handler = OobaRequestHandler(request, model_name, request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
|
@ -84,15 +108,14 @@ def stream(ws):
|
|||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
if not event:
|
||||
r, _ = handler.handle_ratelimited()
|
||||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
try:
|
||||
cluster_backend = get_a_cluster_backend()
|
||||
response = generator(llm_request, cluster_backend)
|
||||
response = generator(llm_request, handler.backend_url)
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
|
@ -134,10 +157,25 @@ def stream(ws):
|
|||
# The has client closed the stream.
|
||||
if request:
|
||||
request.close()
|
||||
ws.close()
|
||||
try:
|
||||
ws.close()
|
||||
except:
|
||||
pass
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
message_num += 1
|
||||
|
@ -149,7 +187,19 @@ def stream(ws):
|
|||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
is_error=not response
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||
|
@ -161,12 +211,24 @@ def stream(ws):
|
|||
if request:
|
||||
request.close()
|
||||
ws.close()
|
||||
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
is_error=True
|
||||
)
|
||||
return
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers()
|
||||
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
|
@ -176,5 +238,19 @@ def stream(ws):
|
|||
# The client closed the stream.
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||
)
|
||||
try:
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
except:
|
||||
pass
|
|
@ -1,19 +0,0 @@
|
|||
from flask import Blueprint
|
||||
|
||||
from ..request_handler import before_request
|
||||
from ..server_error import handle_server_error
|
||||
|
||||
bp = Blueprint('v2', __name__)
|
||||
|
||||
|
||||
@bp.before_request
|
||||
def before_bp_request():
|
||||
return before_request()
|
||||
|
||||
|
||||
@bp.errorhandler(500)
|
||||
def handle_error(e):
|
||||
return handle_server_error(e)
|
||||
|
||||
|
||||
from . import generate, info, proxy, generate_stream
|
|
@ -4,7 +4,7 @@ from threading import Thread
|
|||
from llm_server import opts
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.routes.v2.generate_stats import generate_stats
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.inferencer import start_workers
|
||||
from llm_server.workers.mainer import main_background_thread
|
||||
from llm_server.workers.moderator import start_moderation_workers
|
||||
|
|
12
server.py
12
server.py
|
@ -21,8 +21,7 @@ from llm_server.database.create import create_db
|
|||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import old_v1_bp
|
||||
from llm_server.routes.v2 import bp
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.sock import init_socketio
|
||||
|
||||
# TODO: per-backend workers
|
||||
|
@ -66,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api
|
|||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from llm_server.llm import redis
|
||||
from llm_server.routes.v2.generate_stats import generate_stats
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
app = Flask(__name__)
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v2/')
|
||||
app.register_blueprint(old_v1_bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(bp, url_prefix='/api/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
flask_cache.init_app(app)
|
||||
flask_cache.clear()
|
||||
|
@ -133,8 +131,8 @@ def home():
|
|||
default_active_gen_workers=default_backend_info['processing'],
|
||||
default_proompters_in_queue=default_backend_info['queued'],
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
|
||||
client_api=f'https://{base_client_api}/v2',
|
||||
ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
|
||||
client_api=f'https://{base_client_api}/v1',
|
||||
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
|
||||
default_estimated_wait=default_estimated_wait_sec,
|
||||
mode_name=mode_ui_names[opts.mode][0],
|
||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||
|
|
Reference in New Issue