import json
import time
import traceback

from flask import request

from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import increment_token_uses, log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock

# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)


@sock.route('/api/v1/stream')
def stream(ws):
    """Websocket endpoint that streams LLM completions token-by-token.

    Protocol (SillyTavern-compatible): the client sends a JSON body with at
    least a 'prompt' key; the server replies with a sequence of
    ``{'event': 'text_stream', 'message_num': int, 'text': str}`` frames
    followed by ``{'event': 'stream_end', 'message_num': int}``.

    Side effects: forwards the prompt to the vLLM backend via ``generator``
    and records each completed generation with ``log_prompt``.
    """
    if not opts.enable_streaming:
        # TODO: return a formatted ST error message
        return 'disabled', 401

    message_num = 0
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if not request_valid_json or not request_json_body.get('prompt'):
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': 'Invalid JSON'
            }))
            message_num += 1
        else:
            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            handler = OobaRequestHandler(request, request_json_body)
            # NOTE(review): `token` is read but never used and
            # `increment_token_uses` is imported but never called — matches
            # the TODO above about logging the token. Left in place so the
            # intent is not lost.
            token = request_json_body.get('X-API-KEY')

            generated_text = ''
            input_prompt = None
            response_status_code = 0
            start_time = time.time()

            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': invalid_response
                }))
                # Fix: keep message_num monotonic — every other send
                # increments it, this branch previously did not.
                message_num += 1
            else:
                input_prompt = request_json_body['prompt']
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': input_prompt,
                    'stream': True,
                }
                try:
                    response = generator(msg_to_backend)

                    # Be extra careful when getting attributes from the
                    # response object (a failed backend call may return an
                    # object without status_code).
                    try:
                        response_status_code = response.status_code
                    except Exception:  # fix: was a bare except
                        response_status_code = 0

                    # The vLLM backend delimits JSON messages with a null
                    # byte; accumulate bytes until one is seen.
                    partial_response = b''
                    for chunk in response.iter_content(chunk_size=1):
                        partial_response += chunk
                        if partial_response.endswith(b'\x00'):
                            # Remove the null character and decode the byte string to a string
                            json_str = partial_response[:-1].decode()
                            json_obj = json.loads(json_str)
                            try:
                                # The backend returns the full text so far;
                                # strip the prompt + already-sent text to get
                                # only the new delta.
                                new = json_obj['text'][0].split(input_prompt + generated_text)[1]
                            except IndexError:
                                # ????
                                continue
                            ws.send(json.dumps({
                                'event': 'text_stream',
                                'message_num': message_num,
                                'text': new
                            }))
                            message_num += 1
                            generated_text = generated_text + new
                            partial_response = b''  # Reset the partial response
                        # If there is no more data, break the loop
                        if not chunk:
                            break

                    response.close()
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    generated_tokens = tokenize(generated_text)
                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
                except Exception:  # fix: was a bare except; don't swallow KeyboardInterrupt/SystemExit
                    # Surface the failure to the client as a SillyTavern
                    # error block appended to whatever was already streamed,
                    # and still log the (partial) generation.
                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
                    generated_tokens = tokenize(generated_text)
                    traceback.print_exc()
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': generated_text
                    }))
                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)

            # NOTE(review): original indentation was lost; this end-of-stream
            # marker is placed after each generation attempt (success or
            # error) — confirm against the ST client's expectations.
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))