This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/routes/v1/generate_stream.py

204 lines
8.1 KiB
Python

import json
import time
import traceback
import ujson
from flask import request
from redis import Redis
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...logging import create_logger
from ...sock import sock
# Stacking multiple @sock.route() decorators on one handler raises a TypeError on the /v1/stream endpoint,
# so each websocket route gets its own handler function.
# Module-level logger used by the streaming handler in this file.
_logger = create_logger('GenerateStream')
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
    """Answer requests routed through ``bp.route`` on the stream URLs.

    These URLs are meant to be used as websockets, so this handler only
    returns an explanatory message together with status code 400.

    Args:
        model_name: Optional model segment captured from the URL; unused.

    Returns:
        A ``(body, status)`` tuple.
    """
    body = 'This is a websocket endpoint.'
    status = 400
    return body, status
@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
    """Websocket entry point for ``/v1/stream`` (no model given in the URL).

    Args:
        ws: The websocket connection passed in by ``sock.route``.
    """
    # No model segment in this URL, so the shared handler gets ``None``.
    do_stream(ws, None)
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
    """Websocket entry point for ``/<model_name>/v1/stream``.

    Args:
        ws: The websocket connection passed in by ``sock.route``.
        model_name: Model segment captured from the URL.
    """
    # Forward the URL-selected model to the shared handler.
    do_stream(ws, model_name=model_name)
def do_stream(ws, model_name):
    """Serve a streaming text-generation request over an open websocket.

    A JSON message with at least a ``prompt`` key is read from ``ws``. The
    API key is checked here because the credentials travel inside that
    message. The request is validated, placed on ``priority_queue`` and the
    generated text is relayed from a Redis stream to the client as
    ``text_stream`` events, followed by a ``stream_end`` event.

    Args:
        ws: The websocket connection passed in by ``sock.route``.
        model_name: Model selected through the URL, or ``None``.

    Returns:
        ``None`` on most paths. Some early exits return a
        ``(message, status)`` tuple or the value of ``require_api_key``.
        NOTE(review): nothing is sent to the client on those paths, and the
        return value is presumably ignored for websocket handlers - confirm.

    Raises:
        NotImplementedError: If the selected backend's mode is not ``'vllm'``.
    """
    event_id = None
    try:
        def send_err_and_quit(quitting_err_msg):
            """Send the error text to the client, close the socket, log the failure.

            Reads ``handler``, ``input_prompt``, ``r_headers``, ``r_url`` and
            ``response_status_code`` from the enclosing scope; all of them are
            assigned before the first call.
            """
            # The socket may already be gone (e.g. the client disconnected
            # mid-stream). A failed send must not prevent the failure from
            # being logged, and must not raise out of an ``except`` handler.
            try:
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': 0,
                    'text': quitting_err_msg
                }))
                ws.send(json.dumps({
                    'event': 'stream_end',
                    'message_num': 1
                }))
                ws.close()
            except Exception as send_exc:
                _logger.debug(f'Could not deliver error message to client: {send_exc}')
            log_to_db(ip=handler.client_ip,
                      token=handler.token,
                      prompt=input_prompt,
                      response=quitting_err_msg,
                      gen_time=None,
                      parameters=handler.parameters,
                      headers=r_headers,
                      backend_response_code=response_status_code,
                      request_url=r_url,
                      backend_url=handler.backend_url,
                      response_tokens=None,
                      is_error=True
                      )

        if not GlobalConfig.get().enable_streaming:
            return 'Streaming disabled', 403

        r_headers = dict(request.headers)
        r_url = request.url
        message_num = 0

        while ws.connected:
            message = ws.receive()
            request_valid_json, request_json_body = validate_json(message)
            if not request_valid_json or not request_json_body.get('prompt'):
                return 'Invalid JSON', 400

            # We have to do auth ourselves since the details are sent in the message.
            auth_failure = require_api_key(request_json_body)
            if auth_failure:
                return auth_failure

            handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
            if handler.offline:
                msg = f'{handler.selected_model} is not a valid model choice.'
                _logger.debug(msg)
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': 0,
                    'text': msg
                }))
                return

            if handler.cluster_backend_info['mode'] != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            input_prompt = request_json_body['prompt']
            response_status_code = 0
            start_time = time.time()

            # Collect a rate-limit or validation error message, if any.
            err_msg = None
            if handler.is_client_ratelimited():
                r, _ = handler.handle_ratelimited(do_log=False)
                err_msg = r.json['results'][0]['text']
            else:
                request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
                if not request_valid:
                    err_msg = invalid_response[0].json['results'][0]['text']
            if err_msg:
                send_err_and_quit(err_msg)
                return

            handler.parameters, _ = handler.get_parameters()
            handler.prompt = input_prompt
            handler.request_json_body = {
                'prompt': handler.prompt,
                **handler.parameters
            }

            # Queue the request; a falsy event is treated as rate-limited.
            event = None
            if not handler.is_client_ratelimited():
                event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
            if not event:
                r = handler.handle_ratelimited()
                send_err_and_quit(r[0].data)
                return

            # Remembered so the outer ``finally`` can publish the notification.
            event_id = event.event_id

            _, stream_name, error_msg = event.wait()
            if error_msg:
                _logger.error(f'Stream failed to start streaming: {error_msg}')
                ws.close(reason=1014, message='Request Timeout')
                return

            stream_redis = Redis(db=8)
            # Chunks are joined once at the end instead of re-concatenating
            # the whole string for every received piece.
            text_chunks = []

            try:
                last_id = '0-0'  # The ID of the last entry we read.
                while True:
                    stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().redis_stream_timeout)
                    if not stream_data:
                        _logger.error(f"No message received in {GlobalConfig.get().redis_stream_timeout / 1000} seconds, closing stream.")
                        return
                    for stream_index, item in stream_data[0][1]:
                        last_id = stream_index
                        data = ujson.loads(item[b'data'])
                        if data['error']:
                            _logger.error(f'Encountered error while streaming: {data["error"]}')
                            send_err_and_quit('Encountered exception while streaming.')
                            return
                        elif data['new']:
                            ws.send(json.dumps({
                                'event': 'text_stream',
                                'message_num': message_num,
                                'text': data['new']
                            }))
                            message_num += 1
                            text_chunks.append(data['new'])
                        elif data['completed']:
                            return
            except Exception:
                # Print the traceback first so it is never lost, then try to
                # tell the client. ``Exception`` (not a bare ``except``) lets
                # KeyboardInterrupt/SystemExit propagate.
                traceback.print_exc()
                send_err_and_quit('Encountered exception while streaming.')
                return
            finally:
                try:
                    ws.send(json.dumps({
                        'event': 'stream_end',
                        'message_num': message_num
                    }))
                except Exception:
                    # The client closed the stream (or it was already closed
                    # by send_err_and_quit()).
                    pass
                try:
                    if stream_name:
                        stream_redis.delete(stream_name)
                finally:
                    # Release the per-request Redis client deterministically.
                    stream_redis.close()
                end_time = time.time()
                elapsed_time = end_time - start_time
                # NOTE(review): when send_err_and_quit() already ran, this
                # writes a second row for the same request - confirm that is
                # intended.
                log_to_db(ip=handler.client_ip,
                          token=handler.token,
                          prompt=input_prompt,
                          response=''.join(text_chunks),
                          gen_time=elapsed_time,
                          parameters=handler.parameters,
                          headers=r_headers,
                          backend_response_code=response_status_code,
                          request_url=r_url,
                          backend_url=handler.backend_url
                          )
    finally:
        # Published whenever an event was queued, on every exit path.
        if event_id:
            redis.publish(f'notifications:{event_id}', 'canceled')
        try:
            # Must close the connection or greenlets will complain.
            ws.close()
        except Exception:
            pass