import json
import time
import traceback

from flask import request

from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...sock import sock


# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
# We solve this by splitting the routes
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
    """Plain-HTTP fallback: this path is websocket-only, so reject HTTP requests."""
    return 'This is a websocket endpoint.', 400


@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
    """Websocket entry point with no model in the URL (server picks the model)."""
    do_stream(ws, model_name=None)


@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
    """Websocket entry point with an explicit model name in the URL."""
    do_stream(ws, model_name)


def do_stream(ws, model_name):
    """Handle one websocket streaming session.

    Repeatedly receives JSON prompt messages on the socket, queues them,
    forwards them to the backend with ``stream: True`` and relays the
    generated text back to the client as ``text_stream`` events, ending
    each generation with a ``stream_end`` event.

    :param ws: the flask-sock websocket connection.
    :param model_name: model selected via the URL, or ``None`` to let the
        handler choose.
    """

    def send_err_and_quit(quitting_err_msg):
        # Deliver the error as a normal text event followed by stream_end,
        # then record the failed request. Relies on the enclosing scope's
        # handler / input_prompt / response_status_code being set, so it is
        # only called after the handler has been constructed.
        ws.send(json.dumps({
            'event': 'text_stream',
            'message_num': 0,
            'text': quitting_err_msg
        }))
        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': 1
        }))
        log_prompt(ip=handler.client_ip,
                   token=handler.token,
                   prompt=input_prompt,
                   response=quitting_err_msg,
                   gen_time=None,
                   parameters=handler.parameters,
                   headers=r_headers,
                   backend_response_code=response_status_code,
                   request_url=r_url,
                   # Use the backend URL like the success-path log_prompt call
                   # below (was handler.cluster_backend_info, which is the
                   # whole backend-info dict, not a URL).
                   backend_url=handler.backend_url,
                   response_tokens=None,
                   is_error=True
                   )

    if not opts.enable_streaming:
        return 'Streaming is disabled', 500

    r_headers = dict(request.headers)
    r_url = request.url
    message_num = 0

    try:
        # One websocket connection can carry multiple generation requests.
        while ws.connected:
            message = ws.receive()
            request_valid_json, request_json_body = validate_json(message)
            if not request_valid_json or not request_json_body.get('prompt'):
                return 'Invalid JSON', 400
            else:
                auth_failure = require_api_key(request_json_body)
                if auth_failure:
                    return auth_failure

                handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
                if handler.cluster_backend_info['mode'] != 'vllm':
                    # TODO: implement other backends
                    raise NotImplementedError

                generated_text = ''
                input_prompt = request_json_body['prompt']
                response_status_code = 0
                start_time = time.time()

                # Ratelimit / request validation happens before anything is queued.
                err_msg = None
                if handler.is_client_ratelimited():
                    r, _ = handler.handle_ratelimited(do_log=False)
                    err_msg = r.json['results'][0]['text']
                else:
                    request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
                    if not request_valid:
                        err_msg = invalid_response[0].json['results'][0]['text']
                if err_msg:
                    send_err_and_quit(err_msg)
                    return

                llm_request = {
                    **handler.parameters,
                    'prompt': input_prompt,
                    'stream': True,
                }

                # Add a dummy event to the queue and wait for it to reach a worker
                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
                if not event:
                    # Queue refused the item -> treat it as a ratelimit.
                    r, _ = handler.handle_ratelimited()
                    err_msg = r.json['results'][0]['text']
                    send_err_and_quit(err_msg)
                    return

                # Wait for a worker to get our request and discard it.
                _, _, _ = event.wait()

                try:
                    response = generator(llm_request, handler.backend_url)
                    if not response:
                        error_msg = 'Failed to reach backend while streaming.'
                        print('Streaming failed:', error_msg)
                        msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
                        ws.send(json.dumps({
                            'event': 'text_stream',
                            'message_num': message_num,
                            'text': msg
                        }))
                    else:
                        # Be extra careful when getting attributes from the response object
                        try:
                            response_status_code = response.status_code
                        except:
                            response_status_code = 0

                        partial_response = b''

                        # The backend emits NUL-delimited JSON objects; accumulate
                        # bytes until a delimiter shows up, then decode each piece.
                        for chunk in response.iter_content(chunk_size=1):
                            partial_response += chunk
                            if partial_response.endswith(b'\x00'):
                                json_strs = partial_response.split(b'\x00')
                                for json_str in json_strs:
                                    if json_str:
                                        try:
                                            json_obj = json.loads(json_str.decode())
                                            # The backend returns the full text so far;
                                            # strip the prompt + previously sent text to
                                            # get only the new fragment.
                                            new = json_obj['text'][0].split(input_prompt + generated_text)[1]
                                            generated_text = generated_text + new
                                        except IndexError:
                                            # ????
                                            continue

                                        try:
                                            ws.send(json.dumps({
                                                'event': 'text_stream',
                                                'message_num': message_num,
                                                'text': new
                                            }))
                                        except:
                                            # The client has closed the stream.
                                            if response:
                                                # Cancel the backend?
                                                response.close()
                                            # used to log here
                                            return
                                        message_num += 1
                                partial_response = b''  # Reset the partial response

                            # If there is no more data, break the loop
                            if not chunk:
                                break

                        if response:
                            response.close()
                    # used to log here
                except:
                    # Any failure mid-stream is reported to the client as one
                    # final text event appended to whatever was generated.
                    traceback.print_exc()
                    generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': generated_text
                    }))
                    # used to log here
                finally:
                    # The worker incremented it, we'll decrement it.
                    decrement_ip_count(handler.client_ip, 'processing_ips')
                    decr_active_workers(handler.selected_model, handler.backend_url)

                try:
                    ws.send(json.dumps({
                        'event': 'stream_end',
                        'message_num': message_num
                    }))
                except:
                    # The client closed the stream.
                    pass

                end_time = time.time()
                elapsed_time = end_time - start_time
                log_prompt(ip=handler.client_ip,
                           token=handler.token,
                           prompt=input_prompt,
                           response=generated_text,
                           gen_time=elapsed_time,
                           parameters=handler.parameters,
                           headers=r_headers,
                           backend_response_code=response_status_code,
                           request_url=r_url,
                           backend_url=handler.backend_url
                           )
    finally:
        try:
            # Must close the connection or greenlets will complain.
            ws.close()
        except:
            pass