get streaming working, remove /v2/
This commit is contained in:
parent
b10d22ca0d
commit
25ec56a5ef
|
@ -10,7 +10,7 @@ from llm_server.config.load import load_config, parse_backends
|
||||||
from llm_server.custom_redis import redis
|
from llm_server.custom_redis import redis
|
||||||
from llm_server.database.create import create_db
|
from llm_server.database.create import create_db
|
||||||
from llm_server.routes.queue import priority_queue
|
from llm_server.routes.queue import priority_queue
|
||||||
from llm_server.routes.v2.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
from llm_server.workers.threader import start_background
|
from llm_server.workers.threader import start_background
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
|
@ -47,9 +47,9 @@ def get_model_choices(regen: bool = False):
|
||||||
estimated_wait_sec = f"{estimated_wait_sec} seconds"
|
estimated_wait_sec = f"{estimated_wait_sec} seconds"
|
||||||
|
|
||||||
model_choices[model] = {
|
model_choices[model] = {
|
||||||
'client_api': f'https://{base_client_api}/v2/{model}',
|
'client_api': f'https://{base_client_api}/{model}',
|
||||||
'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None,
|
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
|
||||||
'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
||||||
'backend_count': len(b),
|
'backend_count': len(b),
|
||||||
'estimated_wait': estimated_wait_sec,
|
'estimated_wait': estimated_wait_sec,
|
||||||
'queued': proompters_in_queue,
|
'queued': proompters_in_queue,
|
||||||
|
@ -73,9 +73,9 @@ def get_model_choices(regen: bool = False):
|
||||||
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
|
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
|
||||||
|
|
||||||
default_backend_dict = {
|
default_backend_dict = {
|
||||||
'client_api': f'https://{base_client_api}/v2',
|
'client_api': f'https://{base_client_api}/v1',
|
||||||
'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None,
|
'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||||
'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled',
|
'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled',
|
||||||
'estimated_wait': default_estimated_wait_sec,
|
'estimated_wait': default_estimated_wait_sec,
|
||||||
'queued': default_proompters_in_queue,
|
'queued': default_proompters_in_queue,
|
||||||
'processing': default_active_gen_workers,
|
'processing': default_active_gen_workers,
|
||||||
|
|
|
@ -6,6 +6,7 @@ from llm_server.cluster.cluster_config import cluster_config
|
||||||
|
|
||||||
|
|
||||||
def tokenize(prompt: str, backend_url: str) -> int:
|
def tokenize(prompt: str, backend_url: str) -> int:
|
||||||
|
assert backend_url
|
||||||
if not prompt:
|
if not prompt:
|
||||||
# The tokenizers have issues when the prompt is None.
|
# The tokenizers have issues when the prompt is None.
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -1,18 +1,19 @@
|
||||||
from flask import Blueprint, jsonify
|
from flask import Blueprint
|
||||||
|
|
||||||
from llm_server.custom_redis import redis
|
from ..request_handler import before_request
|
||||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
from ..server_error import handle_server_error
|
||||||
|
|
||||||
old_v1_bp = Blueprint('v1', __name__)
|
bp = Blueprint('v1', __name__)
|
||||||
|
|
||||||
|
|
||||||
@old_v1_bp.route('/', defaults={'path': ''}, methods=['GET', 'POST'])
|
@bp.before_request
|
||||||
@old_v1_bp.route('/<path:path>', methods=['GET', 'POST'])
|
def before_bp_request():
|
||||||
def fallback(path):
|
return before_request()
|
||||||
base_client_api = redis.get('base_client_api', dtype=str)
|
|
||||||
error_msg = f'The /v1/ endpoint has been depreciated. Please visit {base_client_api} for more information.\nAlso, you must enable "Relaxed API URLS" in settings.'
|
|
||||||
response_msg = format_sillytavern_err(error_msg, error_type='API')
|
@bp.errorhandler(500)
|
||||||
return jsonify({
|
def handle_error(e):
|
||||||
'results': [{'text': response_msg}],
|
return handle_server_error(e)
|
||||||
'result': f'Wrong API path, visit {base_client_api} for more info.'
|
|
||||||
}), 200 # return 200 so we don't trigger an error message in the client's ST
|
|
||||||
|
from . import generate, info, proxy, generate_stream
|
||||||
|
|
|
@ -1,26 +1,39 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
|
from . import bp
|
||||||
from ..helpers.http import require_api_key, validate_json
|
from ..helpers.http import require_api_key, validate_json
|
||||||
from ..ooba_request_handler import OobaRequestHandler
|
from ..ooba_request_handler import OobaRequestHandler
|
||||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||||
from ... import opts
|
from ... import opts
|
||||||
from ...cluster.backend import get_a_cluster_backend
|
|
||||||
from ...database.database import log_prompt
|
from ...database.database import log_prompt
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
from ...llm.vllm import tokenize
|
from ...llm.vllm import tokenize
|
||||||
from ...sock import sock
|
from ...sock import sock
|
||||||
|
|
||||||
|
|
||||||
# TODO: have workers process streaming requests
|
# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
|
||||||
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
|
# We solve this by splitting the routes
|
||||||
|
|
||||||
@sock.route('/api/v1/stream')
|
@bp.route('/stream')
|
||||||
def stream(ws):
|
def stream():
|
||||||
|
return 'This is a websocket endpoint.', 400
|
||||||
|
|
||||||
|
|
||||||
|
@sock.route('/stream', bp=bp)
|
||||||
|
def stream_without_model(ws):
|
||||||
|
do_stream(ws, model_name=None)
|
||||||
|
|
||||||
|
|
||||||
|
@sock.route('/<model_name>/v1/stream', bp=bp)
|
||||||
|
def stream_with_model(ws, model_name=None):
|
||||||
|
do_stream(ws, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def do_stream(ws, model_name):
|
||||||
def send_err_and_quit(quitting_err_msg):
|
def send_err_and_quit(quitting_err_msg):
|
||||||
ws.send(json.dumps({
|
ws.send(json.dumps({
|
||||||
'event': 'text_stream',
|
'event': 'text_stream',
|
||||||
|
@ -32,23 +45,33 @@ def stream(ws):
|
||||||
'message_num': 1
|
'message_num': 1
|
||||||
}))
|
}))
|
||||||
ws.close()
|
ws.close()
|
||||||
log_in_bg(quitting_err_msg, is_error=True)
|
log_prompt(ip=handler.client_ip,
|
||||||
|
token=handler.token,
|
||||||
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
|
prompt=input_prompt,
|
||||||
generated_tokens = tokenize(generated_text_bg)
|
response=quitting_err_msg,
|
||||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
|
gen_time=elapsed_time,
|
||||||
|
parameters=handler.parameters,
|
||||||
|
headers=r_headers,
|
||||||
|
backend_response_code=response_status_code,
|
||||||
|
request_url=r_url,
|
||||||
|
backend_url=handler.cluster_backend_info,
|
||||||
|
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||||
|
is_error=True
|
||||||
|
)
|
||||||
|
|
||||||
if not opts.enable_streaming:
|
if not opts.enable_streaming:
|
||||||
return 'Streaming is disabled', 401
|
return 'Streaming is disabled', 500
|
||||||
|
|
||||||
cluster_backend = None
|
|
||||||
r_headers = dict(request.headers)
|
r_headers = dict(request.headers)
|
||||||
r_url = request.url
|
r_url = request.url
|
||||||
message_num = 0
|
message_num = 0
|
||||||
|
|
||||||
while ws.connected:
|
while ws.connected:
|
||||||
message = ws.receive()
|
message = ws.receive()
|
||||||
request_valid_json, request_json_body = validate_json(message)
|
request_valid_json, request_json_body = validate_json(message)
|
||||||
|
|
||||||
if not request_valid_json or not request_json_body.get('prompt'):
|
if not request_valid_json or not request_json_body.get('prompt'):
|
||||||
|
ws.close()
|
||||||
return 'Invalid JSON', 400
|
return 'Invalid JSON', 400
|
||||||
else:
|
else:
|
||||||
if opts.mode != 'vllm':
|
if opts.mode != 'vllm':
|
||||||
|
@ -57,9 +80,10 @@ def stream(ws):
|
||||||
|
|
||||||
auth_failure = require_api_key(request_json_body)
|
auth_failure = require_api_key(request_json_body)
|
||||||
if auth_failure:
|
if auth_failure:
|
||||||
|
ws.close()
|
||||||
return auth_failure
|
return auth_failure
|
||||||
|
|
||||||
handler = OobaRequestHandler(request, request_json_body)
|
handler = OobaRequestHandler(request, model_name, request_json_body)
|
||||||
generated_text = ''
|
generated_text = ''
|
||||||
input_prompt = request_json_body['prompt']
|
input_prompt = request_json_body['prompt']
|
||||||
response_status_code = 0
|
response_status_code = 0
|
||||||
|
@ -84,15 +108,14 @@ def stream(ws):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add a dummy event to the queue and wait for it to reach a worker
|
# Add a dummy event to the queue and wait for it to reach a worker
|
||||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
|
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||||
if not event:
|
if not event:
|
||||||
r, _ = handler.handle_ratelimited()
|
r, _ = handler.handle_ratelimited()
|
||||||
err_msg = r.json['results'][0]['text']
|
err_msg = r.json['results'][0]['text']
|
||||||
send_err_and_quit(err_msg)
|
send_err_and_quit(err_msg)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
cluster_backend = get_a_cluster_backend()
|
response = generator(llm_request, handler.backend_url)
|
||||||
response = generator(llm_request, cluster_backend)
|
|
||||||
if not response:
|
if not response:
|
||||||
error_msg = 'Failed to reach backend while streaming.'
|
error_msg = 'Failed to reach backend while streaming.'
|
||||||
print('Streaming failed:', error_msg)
|
print('Streaming failed:', error_msg)
|
||||||
|
@ -134,10 +157,25 @@ def stream(ws):
|
||||||
# The has client closed the stream.
|
# The has client closed the stream.
|
||||||
if request:
|
if request:
|
||||||
request.close()
|
request.close()
|
||||||
|
try:
|
||||||
ws.close()
|
ws.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
log_prompt(ip=handler.client_ip,
|
||||||
|
token=handler.token,
|
||||||
|
prompt=input_prompt,
|
||||||
|
response=generated_text,
|
||||||
|
gen_time=elapsed_time,
|
||||||
|
parameters=handler.parameters,
|
||||||
|
headers=r_headers,
|
||||||
|
backend_response_code=response_status_code,
|
||||||
|
request_url=r_url,
|
||||||
|
backend_url=handler.backend_url,
|
||||||
|
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
message_num += 1
|
message_num += 1
|
||||||
|
@ -149,7 +187,19 @@ def stream(ws):
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
|
log_prompt(ip=handler.client_ip,
|
||||||
|
token=handler.token,
|
||||||
|
prompt=input_prompt,
|
||||||
|
response=generated_text,
|
||||||
|
gen_time=elapsed_time,
|
||||||
|
parameters=handler.parameters,
|
||||||
|
headers=r_headers,
|
||||||
|
backend_response_code=response_status_code,
|
||||||
|
request_url=r_url,
|
||||||
|
backend_url=handler.backend_url,
|
||||||
|
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||||
|
is_error=not response
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||||
|
@ -161,12 +211,24 @@ def stream(ws):
|
||||||
if request:
|
if request:
|
||||||
request.close()
|
request.close()
|
||||||
ws.close()
|
ws.close()
|
||||||
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
|
log_prompt(ip=handler.client_ip,
|
||||||
|
token=handler.token,
|
||||||
|
prompt=input_prompt,
|
||||||
|
response=generated_text,
|
||||||
|
gen_time=None,
|
||||||
|
parameters=handler.parameters,
|
||||||
|
headers=r_headers,
|
||||||
|
backend_response_code=response_status_code,
|
||||||
|
request_url=r_url,
|
||||||
|
backend_url=handler.backend_url,
|
||||||
|
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||||
|
is_error=True
|
||||||
|
)
|
||||||
return
|
return
|
||||||
finally:
|
finally:
|
||||||
# The worker incremented it, we'll decrement it.
|
# The worker incremented it, we'll decrement it.
|
||||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||||
decr_active_workers()
|
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||||
try:
|
try:
|
||||||
ws.send(json.dumps({
|
ws.send(json.dumps({
|
||||||
'event': 'stream_end',
|
'event': 'stream_end',
|
||||||
|
@ -176,5 +238,19 @@ def stream(ws):
|
||||||
# The client closed the stream.
|
# The client closed the stream.
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
log_prompt(ip=handler.client_ip,
|
||||||
|
token=handler.token,
|
||||||
|
prompt=input_prompt,
|
||||||
|
response=generated_text,
|
||||||
|
gen_time=elapsed_time,
|
||||||
|
parameters=handler.parameters,
|
||||||
|
headers=r_headers,
|
||||||
|
backend_response_code=response_status_code,
|
||||||
|
request_url=r_url,
|
||||||
|
backend_url=handler.backend_url,
|
||||||
|
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||||
|
)
|
||||||
|
try:
|
||||||
ws.close() # this is important if we encountered and error and exited early.
|
ws.close() # this is important if we encountered and error and exited early.
|
||||||
|
except:
|
||||||
|
pass
|
|
@ -1,19 +0,0 @@
|
||||||
from flask import Blueprint
|
|
||||||
|
|
||||||
from ..request_handler import before_request
|
|
||||||
from ..server_error import handle_server_error
|
|
||||||
|
|
||||||
bp = Blueprint('v2', __name__)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.before_request
|
|
||||||
def before_bp_request():
|
|
||||||
return before_request()
|
|
||||||
|
|
||||||
|
|
||||||
@bp.errorhandler(500)
|
|
||||||
def handle_error(e):
|
|
||||||
return handle_server_error(e)
|
|
||||||
|
|
||||||
|
|
||||||
from . import generate, info, proxy, generate_stream
|
|
|
@ -4,7 +4,7 @@ from threading import Thread
|
||||||
from llm_server import opts
|
from llm_server import opts
|
||||||
from llm_server.cluster.stores import redis_running_models
|
from llm_server.cluster.stores import redis_running_models
|
||||||
from llm_server.cluster.worker import cluster_worker
|
from llm_server.cluster.worker import cluster_worker
|
||||||
from llm_server.routes.v2.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
from llm_server.workers.inferencer import start_workers
|
from llm_server.workers.inferencer import start_workers
|
||||||
from llm_server.workers.mainer import main_background_thread
|
from llm_server.workers.mainer import main_background_thread
|
||||||
from llm_server.workers.moderator import start_moderation_workers
|
from llm_server.workers.moderator import start_moderation_workers
|
||||||
|
|
12
server.py
12
server.py
|
@ -21,8 +21,7 @@ from llm_server.database.create import create_db
|
||||||
from llm_server.pre_fork import server_startup
|
from llm_server.pre_fork import server_startup
|
||||||
from llm_server.routes.openai import openai_bp
|
from llm_server.routes.openai import openai_bp
|
||||||
from llm_server.routes.server_error import handle_server_error
|
from llm_server.routes.server_error import handle_server_error
|
||||||
from llm_server.routes.v1 import old_v1_bp
|
from llm_server.routes.v1 import bp
|
||||||
from llm_server.routes.v2 import bp
|
|
||||||
from llm_server.sock import init_socketio
|
from llm_server.sock import init_socketio
|
||||||
|
|
||||||
# TODO: per-backend workers
|
# TODO: per-backend workers
|
||||||
|
@ -66,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api
|
||||||
from llm_server.llm.vllm.info import vllm_info
|
from llm_server.llm.vllm.info import vllm_info
|
||||||
from llm_server.custom_redis import flask_cache
|
from llm_server.custom_redis import flask_cache
|
||||||
from llm_server.llm import redis
|
from llm_server.llm import redis
|
||||||
from llm_server.routes.v2.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
init_socketio(app)
|
init_socketio(app)
|
||||||
app.register_blueprint(bp, url_prefix='/api/v2/')
|
app.register_blueprint(bp, url_prefix='/api/')
|
||||||
app.register_blueprint(old_v1_bp, url_prefix='/api/v1/')
|
|
||||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||||
flask_cache.init_app(app)
|
flask_cache.init_app(app)
|
||||||
flask_cache.clear()
|
flask_cache.clear()
|
||||||
|
@ -133,8 +131,8 @@ def home():
|
||||||
default_active_gen_workers=default_backend_info['processing'],
|
default_active_gen_workers=default_backend_info['processing'],
|
||||||
default_proompters_in_queue=default_backend_info['queued'],
|
default_proompters_in_queue=default_backend_info['queued'],
|
||||||
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
|
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
|
||||||
client_api=f'https://{base_client_api}/v2',
|
client_api=f'https://{base_client_api}/v1',
|
||||||
ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
|
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
|
||||||
default_estimated_wait=default_estimated_wait_sec,
|
default_estimated_wait=default_estimated_wait_sec,
|
||||||
mode_name=mode_ui_names[opts.mode][0],
|
mode_name=mode_ui_names[opts.mode][0],
|
||||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||||
|
|
Reference in New Issue