200 lines
7.8 KiB
Python
200 lines
7.8 KiB
Python
import json
|
|
import time
|
|
import traceback
|
|
|
|
import ujson
|
|
from flask import request
|
|
from redis import Redis
|
|
|
|
from . import bp
|
|
from ..helpers.http import require_api_key, validate_json
|
|
from ..ooba_request_handler import OobaRequestHandler
|
|
from ..queue import priority_queue
|
|
from ... import opts
|
|
from ...custom_redis import redis
|
|
from ...database.log_to_db import log_to_db
|
|
from ...sock import sock
|
|
|
|
|
|
# Stacking the @sock.route() decorators raises a TypeError on the /v1/stream endpoint,
# so each websocket route is registered separately below.
|
|
|
|
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
    """Answer plain-HTTP requests to the streaming paths with a 400.

    The websocket handlers for these same paths are registered separately
    via ``sock.route``; this view only tells HTTP callers to use a websocket.
    """
    body, status = 'This is a websocket endpoint.', 400
    return body, status
|
|
|
|
|
|
@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
    """Websocket entry point for ``/v1/stream``; no model is named in the URL."""
    selected_model = None
    do_stream(ws, model_name=selected_model)
|
|
|
|
|
|
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
    """Websocket entry point for ``/<model_name>/v1/stream``.

    The ``model_name`` URL segment is forwarded unchanged to ``do_stream``.
    """
    do_stream(ws, model_name=model_name)
|
|
|
|
|
|
def _send_error_message(ws, text):
    """Send ``text`` to the client as a complete one-message stream, then close ``ws``.

    Emits a ``text_stream`` event (``message_num`` 0) carrying ``text``,
    followed by a ``stream_end`` event (``message_num`` 1), using the same
    event names as the normal streaming path, and then closes the socket.
    """
    ws.send(json.dumps({
        'event': 'text_stream',
        'message_num': 0,
        'text': text
    }))
    ws.send(json.dumps({
        'event': 'stream_end',
        'message_num': 1
    }))
    ws.close()


def do_stream(ws, model_name):
    """Handle a generation request arriving over the websocket ``ws``.

    Reads a JSON message (which must contain ``prompt`` and also carries the
    API-key details), validates and authenticates it, and queues it on
    ``priority_queue`` with ``do_stream=True``. It then relays the entries read
    from the Redis stream named by the queued event back to the client as
    ``text_stream`` events, followed by a ``stream_end`` event. Generations that
    reach the queue, successful or failed, are recorded with ``log_to_db``.

    All errors are reported over the socket. The websocket route functions in
    this module discard this function's return value, so HTTP-style
    ``(body, status)`` returns would never reach the client.

    Args:
        ws: The websocket connection supplied by ``sock.route``.
        model_name: Model selected via the URL path, or ``None`` for the
            default route.
    """
    event_id = None
    try:
        if not opts.enable_streaming:
            _send_error_message(ws, 'Streaming disabled')
            return

        r_headers = dict(request.headers)
        r_url = request.url
        message_num = 0

        while ws.connected:
            message = ws.receive()
            request_valid_json, request_json_body = validate_json(message)
            if not request_valid_json or not request_json_body.get('prompt'):
                _send_error_message(ws, 'Invalid JSON')
                return

            # We have to do auth ourselves since the details are sent in the message.
            auth_failure = require_api_key(request_json_body)
            if auth_failure:
                # NOTE(review): the structure of ``auth_failure`` is not visible
                # here, so a generic message is sent instead of its contents.
                _send_error_message(ws, 'Authentication failed.')
                return

            handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
            if handler.offline:
                msg = f'{handler.selected_model} is not a valid model choice.'
                print(msg)
                _send_error_message(ws, msg)
                return

            if handler.cluster_backend_info['mode'] != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            input_prompt = request_json_body['prompt']
            response_status_code = 0
            start_time = time.time()
            # Set once send_err_and_quit() has logged this request, so the
            # streaming cleanup below does not write a second row for it.
            error_logged = False

            def send_err_and_quit(quitting_err_msg):
                """Report ``quitting_err_msg`` to the client, close the socket, and log it as an error."""
                nonlocal error_logged
                _send_error_message(ws, quitting_err_msg)
                log_to_db(ip=handler.client_ip,
                          token=handler.token,
                          prompt=input_prompt,
                          response=quitting_err_msg,
                          gen_time=None,
                          parameters=handler.parameters,
                          headers=r_headers,
                          backend_response_code=response_status_code,
                          request_url=r_url,
                          backend_url=handler.backend_url,
                          response_tokens=None,
                          is_error=True
                          )
                error_logged = True

            err_msg = None
            if handler.is_client_ratelimited():
                r, _ = handler.handle_ratelimited(do_log=False)
                err_msg = r.json['results'][0]['text']
            else:
                request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
                if not request_valid:
                    err_msg = invalid_response[0].json['results'][0]['text']
            if err_msg:
                send_err_and_quit(err_msg)
                return

            handler.parameters, _ = handler.get_parameters()
            handler.prompt = input_prompt
            handler.request_json_body = {
                'prompt': handler.prompt,
                **handler.parameters
            }

            event = None
            if not handler.is_client_ratelimited():
                event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
            if not event:
                # Same extraction as the rate-limit check above. ``.json`` gives
                # the parsed body, whereas ``.data`` is raw bytes that
                # json.dumps() cannot serialize. do_log=False because
                # send_err_and_quit() writes the log row itself.
                r, _ = handler.handle_ratelimited(do_log=False)
                send_err_and_quit(r.json['results'][0]['text'])
                return
            event_id = event.event_id

            _, stream_name, error_msg = event.wait()
            if error_msg == 'closed':
                # NOTE(review): 1014 is the registered "Bad Gateway" websocket
                # close code; confirm it is the intended code for this case.
                ws.close(reason=1014, message='Request Timeout')
                return

            stream_redis = Redis(db=8)
            generated_chunks = []
            try:
                last_id = '0-0'  # The ID of the last entry we read.
                while True:
                    # xread's ``block`` is in milliseconds, hence the /1000 below.
                    stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                    if not stream_data:
                        print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                        return
                    for stream_index, item in stream_data[0][1]:
                        last_id = stream_index
                        data = ujson.loads(item[b'data'])
                        if data['error']:
                            print(data['error'])
                            send_err_and_quit('Encountered exception while streaming.')
                            return
                        elif data['new']:
                            ws.send(json.dumps({
                                'event': 'text_stream',
                                'message_num': message_num,
                                'text': data['new']
                            }))
                            message_num += 1
                            generated_chunks.append(data['new'])
                        elif data['completed']:
                            return
            except Exception:
                # Print first: send_err_and_quit() can itself raise if the client
                # is already gone, which would otherwise hide the original error.
                traceback.print_exc()
                send_err_and_quit('Encountered exception while streaming.')
            finally:
                if not error_logged:
                    # send_err_and_quit() already sent stream_end and closed the
                    # socket on the error paths; only finish the normal paths here.
                    try:
                        ws.send(json.dumps({
                            'event': 'stream_end',
                            'message_num': message_num
                        }))
                    except Exception:
                        # The client closed the stream.
                        pass
                if stream_name:
                    stream_redis.delete(stream_name)
                if not error_logged:
                    elapsed_time = time.time() - start_time
                    log_to_db(ip=handler.client_ip,
                              token=handler.token,
                              prompt=input_prompt,
                              response=''.join(generated_chunks),
                              gen_time=elapsed_time,
                              parameters=handler.parameters,
                              headers=r_headers,
                              backend_response_code=response_status_code,
                              request_url=r_url,
                              backend_url=handler.backend_url
                              )
    finally:
        if event_id:
            redis.publish(f'notifications:{event_id}', 'canceled')
        try:
            # Must close the connection or greenlets will complain.
            ws.close()
        except Exception:
            pass
|