import json
import time
import traceback

import ujson
from flask import request
from redis import Redis

from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ... import opts
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...sock import sock


# Stacking @sock.route() on the /v1/stream endpoint raises a TypeError,
# so we split the routes instead.

@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
    return 'This is a websocket endpoint.', 400


@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
    do_stream(ws, model_name=None)


@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
    do_stream(ws, model_name)


def do_stream(ws, model_name):
    # Validate and queue the incoming request, then relay the backend's output
    # from a Redis stream to the websocket client.
    event_id = None
    try:
        def send_err_and_quit(quitting_err_msg):
            # Deliver the error as a normal text event, end the stream, and log the failure.
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': 0,
                'text': quitting_err_msg
            }))
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': 1
            }))
            ws.close()
            log_to_db(ip=handler.client_ip,
                      token=handler.token,
                      prompt=input_prompt,
                      response=quitting_err_msg,
                      gen_time=None,
                      parameters=handler.parameters,
                      headers=r_headers,
                      backend_response_code=response_status_code,
                      request_url=r_url,
                      backend_url=handler.backend_url,
                      response_tokens=None,
                      is_error=True
                      )

        if not opts.enable_streaming:
            return 'Streaming disabled', 403

        r_headers = dict(request.headers)
        r_url = request.url
        message_num = 0

        while ws.connected:
            message = ws.receive()
            request_valid_json, request_json_body = validate_json(message)

            if not request_valid_json or not request_json_body.get('prompt'):
                return 'Invalid JSON', 400
            else:
                # We have to do auth ourselves since the details are sent in the message.
                auth_failure = require_api_key(request_json_body)
                if auth_failure:
                    return auth_failure

                handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)

                if handler.offline:
                    msg = f'{handler.selected_model} is not a valid model choice.'
                    print(msg)
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': 0,
                        'text': msg
                    }))
                    return

                if handler.cluster_backend_info['mode'] != 'vllm':
                    # TODO: implement other backends
                    raise NotImplementedError

                input_prompt = request_json_body['prompt']
                response_status_code = 0
                start_time = time.time()

                err_msg = None
                if handler.is_client_ratelimited():
                    r, _ = handler.handle_ratelimited(do_log=False)
                    err_msg = r.json['results'][0]['text']
                else:
                    request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
                    if not request_valid:
                        err_msg = invalid_response[0].json['results'][0]['text']

                if err_msg:
                    send_err_and_quit(err_msg)
                    return

                handler.parameters, _ = handler.get_parameters()
                handler.prompt = input_prompt
                handler.request_json_body = {
                    'prompt': handler.prompt,
                    **handler.parameters
                }

                event = None
                if not handler.is_client_ratelimited():
                    event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
                if not event:
                    r = handler.handle_ratelimited()
                    send_err_and_quit(r[0].data)
                    return

                event_id = event.event_id

                _, stream_name, error_msg = event.wait()
                if error_msg:
                    print('Stream failed to start streaming:', error_msg)
                    ws.close(reason=1014, message='Request Timeout')
                    return

                # The backend worker publishes generated chunks to a Redis stream in DB 8.
                stream_redis = Redis(db=8)
                generated_text = ''

                try:
                    last_id = '0-0'  # The ID of the last entry we read.
                    while True:
                        stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                        if not stream_data:
                            print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                            return
                        else:
                            for stream_index, item in stream_data[0][1]:
                                last_id = stream_index
                                data = ujson.loads(item[b'data'])
                                if data['error']:
                                    print(data['error'])
                                    send_err_and_quit('Encountered exception while streaming.')
                                    return
                                elif data['new']:
                                    ws.send(json.dumps({
                                        'event': 'text_stream',
                                        'message_num': message_num,
                                        'text': data['new']
                                    }))
                                    message_num += 1
                                    generated_text = generated_text + data['new']
                                elif data['completed']:
                                    return
                except:
                    send_err_and_quit('Encountered exception while streaming.')
                    traceback.print_exc()
                finally:
                    try:
                        ws.send(json.dumps({
                            'event': 'stream_end',
                            'message_num': message_num
                        }))
                    except:
                        # The client closed the stream.
                        pass
                    if stream_name:
                        stream_redis.delete(stream_name)
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    log_to_db(ip=handler.client_ip,
                              token=handler.token,
                              prompt=input_prompt,
                              response=generated_text,
                              gen_time=elapsed_time,
                              parameters=handler.parameters,
                              headers=r_headers,
                              backend_response_code=response_status_code,
                              request_url=r_url,
                              backend_url=handler.backend_url
                              )
    finally:
        if event_id:
            # Notify the worker so it can cancel the in-flight request.
            redis.publish(f'notifications:{event_id}', 'canceled')
        try:
            # Must close the connection or greenlets will complain.
            ws.close()
        except:
            pass
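# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of this module). It shows one way a
# client could consume the protocol above using the third-party
# `websocket-client` package: send a JSON message containing at least a
# 'prompt' (plus whatever auth fields require_api_key() expects), then read
# 'text_stream' events until a 'stream_end' event arrives. The host, port,
# and URL prefix below are assumptions that depend on how this blueprint is
# mounted in your deployment.
#
#   import json
#   from websocket import create_connection  # pip install websocket-client
#
#   ws = create_connection('ws://127.0.0.1:5000/api/v1/stream')  # prefix is an assumption
#   ws.send(json.dumps({'prompt': 'Once upon a time'}))
#   while True:
#       msg = json.loads(ws.recv())
#       if msg['event'] == 'text_stream':
#           print(msg['text'], end='', flush=True)
#       elif msg['event'] == 'stream_end':
#           break
#   ws.close()
# ---------------------------------------------------------------------------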