From 2678102153b6ed26ffa87eae5d66d481b56366a9 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Sun, 24 Sep 2023 13:27:27 -0600 Subject: [PATCH] handle error while streaming --- llm_server/routes/v1/generate_stream.py | 91 ++++++++++++++----------- other/vllm/vllm_api_server.py | 2 + 2 files changed, 55 insertions(+), 38 deletions(-) diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 10cd720..2575194 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,8 +1,10 @@ import json import time +import traceback from flask import request +from ..helpers.client import format_sillytavern_err from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler from ... import opts @@ -13,6 +15,7 @@ from ...stream import sock # TODO: have workers process streaming requests +# TODO: make sure to log the token as well (seems to be missing in the DB right now) @sock.route('/api/v1/stream') def stream(ws): @@ -55,47 +58,59 @@ def stream(ws): 'prompt': input_prompt, 'stream': True, } - response = generator(msg_to_backend) - # Be extra careful when getting attributes from the response object try: - response_status_code = response.status_code + response = generator(msg_to_backend) + + # Be extra careful when getting attributes from the response object + try: + response_status_code = response.status_code + except: + response_status_code = 0 + + partial_response = b'' + + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string + json_obj = json.loads(json_str) + try: + new = json_obj['text'][0].split(input_prompt + generated_text)[1] + except IndexError: + # ???? + continue + + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': new + })) + message_num += 1 + + generated_text = generated_text + new + partial_response = b'' # Reset the partial response + + # If there is no more data, break the loop + if not chunk: + break + + response.close() + + end_time = time.time() + elapsed_time = end_time - start_time + generated_tokens = tokenize(generated_text) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens) except: - response_status_code = 0 - - partial_response = b'' - - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string - json_obj = json.loads(json_str) - try: - new = json_obj['text'][0].split(input_prompt + generated_text)[1] - except IndexError: - # ???? - continue - - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': new - })) - message_num += 1 - - generated_text = generated_text + new - partial_response = b'' # Reset the partial response - - # If there is no more data, break the loop - if not chunk: - break - - response.close() - - end_time = time.time() - elapsed_time = end_time - start_time - generated_tokens = tokenize(generated_text) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens) + generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error') + generated_tokens = tokenize(generated_text) + traceback.print_exc() + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': generated_text + })) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens) ws.send(json.dumps({ 'event': 'stream_end', diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py index 0b067fd..502c996 100644 --- a/other/vllm/vllm_api_server.py +++ b/other/vllm/vllm_api_server.py @@ -19,6 +19,8 @@ app = FastAPI() served_model = None +# TODO: figure out ROPE scaling +# TODO: make sure error messages are returned in the response @app.get("/model") async def generate(request: Request) -> Response: