import json
import threading
import time
import traceback

from flask import request

from ..helpers.client import format_sillytavern_err  # NOTE(review): imported but unused in this file — confirm before removing
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock

# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)


@sock.route('/api/v1/stream')
def stream(ws):
    """Websocket endpoint that streams vLLM token output to the client.

    Protocol (SillyTavern-style): the client sends a JSON message containing at
    least a 'prompt' key; the server replies with a sequence of
    ``{'event': 'text_stream', 'message_num': N, 'text': ...}`` frames followed
    by a single ``{'event': 'stream_end', 'message_num': N}`` frame.

    Side effects: every request (valid, invalid, or errored) is logged to the
    database via ``log_prompt`` on a short-lived background thread.

    NOTE(review): several bare ``except:`` clauses below are used deliberately
    as control flow — a failing ``ws.send`` is treated as "client went away".
    They also swallow ``KeyboardInterrupt``/``SystemExit``; consider narrowing.
    """
    # Feature flag: refuse the socket entirely when streaming is disabled.
    if not opts.enable_streaming:
        return 'disabled', 401

    # Snapshot request metadata up front — the Flask request context may not be
    # safely accessible later from the background logging threads.
    r_headers = dict(request.headers)
    r_url = request.url
    message_num = 0  # monotonically increasing frame counter sent to the client

    # One websocket connection can carry multiple sequential generation requests.
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if not request_valid_json or not request_json_body.get('prompt'):
            # Malformed request: report in-band as a text_stream frame rather
            # than closing the socket, so the client can retry.
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': 'Invalid JSON'
            }))
            message_num += 1
        else:
            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            # Token auth is carried in the request body; a truthy return value
            # is a ready-made Flask error response.
            auth_failure = require_api_key(request_json_body)
            if auth_failure:
                return auth_failure

            handler = OobaRequestHandler(request, request_json_body)
            generated_text = ''
            input_prompt = request_json_body['prompt']
            response_status_code = 0  # 0 = backend status unknown/unavailable
            start_time = time.time()

            request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
            if not request_valid:
                # Extract the human-readable error text out of the pre-built
                # Flask error response tuple produced by validate_request.
                err_msg = invalid_response[0].json['results'][0]['text']
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': 0,
                    'text': err_msg
                }))
                ws.send(json.dumps({
                    'event': 'stream_end',
                    'message_num': 1
                }))
                ws.close()  # this is important if we encountered an error and exited early.

                # Log the rejected request off the request thread.
                def background_task():
                    log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)

                # TODO: use async/await instead of threads
                # NOTE(review): start() immediately followed by join() makes
                # this effectively synchronous — confirm whether that's intended.
                thread = threading.Thread(target=background_task)
                thread.start()
                thread.join()
            else:
                # Forward the client's sampling parameters plus the prompt to
                # the backend, forcing streaming mode.
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': input_prompt,
                    'stream': True,
                }
                try:
                    response = generator(msg_to_backend)

                    # Be extra careful when getting attributes from the response object
                    try:
                        response_status_code = response.status_code
                    except:
                        response_status_code = 0

                    partial_response = b''  # buffer for one NUL-delimited JSON record

                    # TODO: handle when the backend is offline
                    # Traceback (most recent call last):
                    #   File "/srv/server/local-llm-server/llm_server/routes/v1/generate_stream.py", line 91, in stream
                    #     for chunk in response.iter_content(chunk_size=1):
                    #                  ^^^^^^^^^^^^^^^^^^^^^
                    # AttributeError: 'NoneType' object has no attribute 'iter_content'
                    # (i.e. generator() can return None when the backend is down
                    #  and this loop then raises into the bare except below)

                    # The vLLM backend streams JSON records separated by b'\x00'.
                    # Read byte-by-byte and flush each complete record.
                    for chunk in response.iter_content(chunk_size=1):
                        partial_response += chunk
                        if partial_response.endswith(b'\x00'):
                            json_strs = partial_response.split(b'\x00')
                            for json_str in json_strs:
                                if json_str:
                                    try:
                                        json_obj = json.loads(json_str.decode())
                                        # The backend echoes the cumulative text
                                        # (prompt + everything generated so far);
                                        # split off only the newly generated tail.
                                        new = json_obj['text'][0].split(input_prompt + generated_text)[1]
                                        generated_text = generated_text + new
                                    except IndexError:
                                        # ???? (split() found no tail — record
                                        # didn't extend the text; skip it)
                                        continue

                                    try:
                                        ws.send(json.dumps({
                                            'event': 'text_stream',
                                            'message_num': message_num,
                                            'text': new
                                        }))
                                    except:
                                        # Client disconnected mid-stream: log
                                        # what was generated so far and bail out
                                        # of the handler entirely.
                                        end_time = time.time()
                                        elapsed_time = end_time - start_time
                                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
                                        return
                                    message_num += 1

                            partial_response = b''  # Reset the partial response

                        # If there is no more data, break the loop
                        if not chunk:
                            break

                    response.close()

                    end_time = time.time()
                    elapsed_time = end_time - start_time

                    # Log the completed generation off the request thread.
                    def background_task_success():
                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))

                    # TODO: use async/await instead of threads
                    thread = threading.Thread(target=background_task_success)
                    thread.start()
                    thread.join()
                except:
                    # Any failure while streaming (backend down, decode error,
                    # etc.): append a formatted error message to the output,
                    # send it to the client, and log the partial result.
                    generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8')
                    traceback.print_exc()
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': generated_text
                    }))

                    def background_task_exception():
                        generated_tokens = tokenize(generated_text)
                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                    # TODO: use async/await instead of threads
                    thread = threading.Thread(target=background_task_exception)
                    thread.start()
                    thread.join()

            # Terminate the stream for this request. If the send fails the
            # client is gone — log the result here since no stream_end landed.
            try:
                ws.send(json.dumps({
                    'event': 'stream_end',
                    'message_num': message_num
                }))
            except:
                end_time = time.time()
                elapsed_time = end_time - start_time
                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))

    ws.close()  # this is important if we encountered an error and exited early.