import json
import time

import requests
from flask import request

from ..helpers.client import format_sillytavern_err
from ... import opts
from ...database import log_prompt
from ...helpers import indefinite_article
from ...llm.hf_textgen.generate import prepare_json
from ...stream import sock


# TODO: have workers process streaming requests

def _resolve_client_ip():
    """Return the caller's address, preferring proxy-supplied headers.

    Checks ``cf-connecting-ip`` first, then the first entry of
    ``x-forwarded-for``, and finally falls back to ``request.remote_addr``.
    """
    cf_ip = request.headers.get('cf-connecting-ip')
    if cf_ip:
        return cf_ip
    forwarded_for = request.headers.get('x-forwarded-for')
    if forwarded_for:
        return forwarded_for.split(',')[0]
    return request.remote_addr


@sock.route('/api/v1/stream')  # TODO: use blueprint route???
def stream(ws):
    """Relay streamed generations from the backend to a websocket client.

    Each message received on ``ws`` is parsed as JSON and, when
    ``opts.mode`` is ``'hf-textgen'``, posted to the backend's
    ``/generate_stream`` endpoint. Every ``data:`` line of the backend
    response is forwarded as a ``text_stream`` event, followed by one
    ``stream_end`` event, after which the request is recorded through
    ``log_prompt``. Messages received in any other mode are read and ignored.

    ``message_num`` keeps counting across all messages of one connection.

    Args:
        ws: The websocket connection supplied by the ``sock`` route.

    Raises:
        ValueError: If a client message or a ``data:`` line is not valid
            JSON (``json.JSONDecodeError`` is a subclass).
        KeyError: If a client message has no ``'prompt'`` key, or a
            non-error backend event has no ``token.text``.
    """
    client_ip = _resolve_client_ip()
    token = request.headers.get('X-Api-Key')

    message_num = 0
    while ws.connected:
        message = ws.receive()
        # Time each request on its own. A single timestamp taken when the
        # socket opened would fold every earlier request (and idle time)
        # into the elapsed time logged for later messages.
        start_time = time.time()
        data = json.loads(message)

        if opts.mode == 'hf-textgen':
            details = {}
            generated_text = ''

            # NOTE(review): verify=False disables TLS certificate verification
            # for the backend call. Kept to preserve existing behaviour;
            # confirm this is intentional for the deployment.
            # The context manager closes the streamed connection even when
            # we stop reading early on a backend error.
            with requests.post(
                f'{opts.backend_url}/generate_stream',
                json=prepare_json(data),
                stream=True,
                verify=False,
            ) as response:
                # Be extra careful when getting attributes from the response object
                response_status_code = getattr(response, 'status_code', 0)

                for raw_line in response.iter_lines():
                    line = raw_line.decode('utf-8')
                    # Only 'data:' lines carry a JSON payload; skip the rest.
                    if not line.startswith('data:'):
                        continue

                    json_data = json.loads(line[5:])
                    # `.get(key, default)` only applies the default when the
                    # key is absent; a present-but-null value would come back
                    # as None, so normalise both fields explicitly.
                    details = json_data.get('details') or {}
                    generated_text = json_data.get('generated_text') or ''

                    if json_data.get('error'):
                        error_type = json_data.get('error_type')
                        error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
                        generated_text = format_sillytavern_err(
                            f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
                            f'HTTP CODE {response_status_code}')
                        ws.send(json.dumps({
                            'event': 'text_stream',
                            'message_num': message_num,
                            'text': generated_text
                        }))
                        # Nothing useful follows an error event.
                        break

                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': json_data['token']['text']
                    }))
                    message_num += 1

            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))

            elapsed_time = time.time() - start_time
            parameters = data.copy()
            del parameters['prompt']

            # `details` is empty when the backend errored or the stream ended
            # without a final details payload; indexing it directly would
            # raise after the client was already answered and skip logging.
            # NOTE(review): this passes None in that case - confirm that
            # log_prompt accepts response_tokens=None.
            log_prompt(
                client_ip,
                token,
                data['prompt'],
                generated_text,
                elapsed_time,
                parameters,
                dict(request.headers),
                response_status_code,
                response_tokens=details.get('generated_tokens'),
            )