import json
import threading
import time
import traceback
from typing import Optional, Union

from flask import request

from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock

# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)


@sock.route('/api/v1/stream')
def stream(ws):
    """Stream a text generation from the backend to a websocket client.

    Each message received on ``ws`` must be JSON (checked by ``validate_json``)
    containing a non-empty ``'prompt'``. The prompt plus the handler's
    parameters are sent to the backend with ``'stream': True``; every new piece
    of generated text is forwarded to the client as
    ``{'event': 'text_stream', 'message_num': n, 'text': piece}`` and the
    generation is terminated with ``{'event': 'stream_end', 'message_num': n}``.

    Raises:
        NotImplementedError: if ``opts.mode`` is not ``'vllm'``.
    """
    # The helpers below close over ``handler``, ``input_prompt``, ``r_headers``
    # and ``r_url`` from this scope; they must only be called after those are set.

    def send_err_and_quit(quitting_err_msg):
        """Send one error text event plus ``stream_end``, close the socket and log an error."""
        ws.send(json.dumps({
            'event': 'text_stream',
            'message_num': 0,
            'text': quitting_err_msg
        }))
        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': 1
        }))
        ws.close()
        log_in_bg(quitting_err_msg, is_error=True)

    def log_in_bg(generated_text_bg, elapsed_time_bg: Optional[Union[int, float]] = None, is_error: bool = False, status_code: Optional[int] = None):
        """Record the prompt/response pair via ``log_prompt``.

        Despite the name, the worker thread is joined immediately, so this call
        blocks until logging has finished.
        """

        def background_task_exception():
            generated_tokens = tokenize(generated_text_bg)
            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)

        # TODO: use async/await instead of threads
        thread = threading.Thread(target=background_task_exception)
        thread.start()
        thread.join()

    def close_backend(resp):
        """Close the backend streaming response if one was obtained (best effort)."""
        # NOTE(review): assumes ``generator`` returns a requests.Response-like
        # object (it exposes ``status_code``/``iter_content``); ``close()``
        # releases the HTTP connection to the backend.
        if resp is None:
            return
        try:
            resp.close()
        except Exception:
            traceback.print_exc()

    if not opts.enable_streaming:
        return 'Streaming is disabled', 401

    r_headers = dict(request.headers)
    r_url = request.url
    message_num = 0
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if not request_valid_json or not request_json_body.get('prompt'):
            return 'Invalid JSON', 400

        if opts.mode != 'vllm':
            # TODO: implement other backends
            raise NotImplementedError

        auth_failure = require_api_key(request_json_body)
        if auth_failure:
            return auth_failure

        handler = OobaRequestHandler(request, request_json_body)
        generated_text = ''
        input_prompt = request_json_body['prompt']
        response_status_code = 0
        start_time = time.time()

        err_msg = None
        if handler.is_client_ratelimited():
            r, _ = handler.handle_ratelimited(do_log=False)
            err_msg = r.json['results'][0]['text']
        else:
            request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
            if not request_valid:
                err_msg = invalid_response[0].json['results'][0]['text']
        if err_msg:
            send_err_and_quit(err_msg)
            return

        llm_request = {
            **handler.parameters,
            'prompt': input_prompt,
            'stream': True,
        }

        # Add a dummy event to the queue and wait for it to reach a worker
        # NOTE(review): nothing here actually waits on ``event`` before
        # generating; confirm against priority_queue whether an explicit wait is
        # needed so the decrements in ``finally`` below pair with the worker's
        # increments.
        event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
        if not event:
            r, _ = handler.handle_ratelimited()
            err_msg = r.json['results'][0]['text']
            send_err_and_quit(err_msg)
            return

        # Pre-bind so the error path can close it even if ``generator`` raises.
        response = None
        try:
            response = generator(llm_request)
            if not response:
                error_msg = 'Failed to reach backend while streaming.'
                print('Streaming failed:', error_msg)
                msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': msg
                }))
            else:
                # Be extra careful when getting attributes from the response object
                response_status_code = getattr(response, 'status_code', 0)

                partial_response = b''
                # The backend delimits JSON messages with NUL bytes.
                for chunk in response.iter_content(chunk_size=1):
                    partial_response += chunk
                    if partial_response.endswith(b'\x00'):
                        json_strs = partial_response.split(b'\x00')
                        for json_str in json_strs:
                            if not json_str:
                                continue
                            try:
                                json_obj = json.loads(json_str.decode())
                                full_text = json_obj['text'][0]
                            except IndexError:
                                # Empty 'text' list: nothing to forward.
                                continue

                            # The backend returns the prompt plus everything
                            # generated so far; forward only the part the client
                            # has not seen. Slicing after a startswith check
                            # (instead of split()[1]) avoids truncating output
                            # that happens to repeat the prefix.
                            already_sent = input_prompt + generated_text
                            if not full_text.startswith(already_sent):
                                continue
                            new = full_text[len(already_sent):]
                            generated_text = generated_text + new

                            try:
                                ws.send(json.dumps({
                                    'event': 'text_stream',
                                    'message_num': message_num,
                                    'text': new
                                }))
                            except Exception:
                                # The client has closed the stream.
                                close_backend(response)
                                ws.close()
                                elapsed_time = time.time() - start_time
                                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
                                return
                            message_num += 1
                        partial_response = b''  # Reset the partial response

                    # If there is no more data, break the loop
                    if not chunk:
                        break

            elapsed_time = time.time() - start_time
            log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
        except Exception:
            traceback.print_exc()
            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
            try:
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': generated_text
                }))
            except Exception:
                # The client is already gone; still clean up and log below.
                traceback.print_exc()
            close_backend(response)
            ws.close()
            log_in_bg(generated_text, is_error=True, status_code=response_status_code)
            return
        finally:
            # The worker incremented it, we'll decrement it.
            decrement_ip_count(handler.client_ip, 'processing_ips')
            decr_active_workers()

        try:
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))
        except Exception:
            # The client closed the stream. The request has already been
            # logged by log_in_bg above, so don't write a duplicate record.
            pass
    ws.close()  # this is important if we encountered an error and exited early.