handle error while streaming

Cyberes 2023-09-24 13:27:27 -06:00
parent cb99c3490e
commit 2678102153
2 changed files with 55 additions and 38 deletions

View File

@@ -1,8 +1,10 @@
 import json
 import time
+import traceback

 from flask import request

+from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
@@ -13,6 +15,7 @@ from ...stream import sock
 # TODO: have workers process streaming requests
+# TODO: make sure to log the token as well (seems to be missing in the DB right now)

 @sock.route('/api/v1/stream')
 def stream(ws):
@@ -55,6 +58,8 @@ def stream(ws):
             'prompt': input_prompt,
             'stream': True,
         }
+
+        try:
             response = generator(msg_to_backend)
             # Be extra careful when getting attributes from the response object
@@ -96,6 +101,16 @@
             elapsed_time = end_time - start_time
             generated_tokens = tokenize(generated_text)
             log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+        except:
+            generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error')
+            generated_tokens = tokenize(generated_text)
+            traceback.print_exc()
+            ws.send(json.dumps({
+                'event': 'text_stream',
+                'message_num': message_num,
+                'text': generated_text
+            }))
+            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
         ws.send(json.dumps({
             'event': 'stream_end',
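
Taken together, the hunks above amount to the pattern sketched below. This is a minimal, self-contained sketch, not the file's actual code: `generator`, `format_sillytavern_err`, and the websocket `ws` are the project's own names from the diff, while the stub helper and the `stream_response` wrapper are assumptions added so the example runs on its own.

```python
import json
import traceback


def format_sillytavern_err(msg, level):
    # Placeholder for the project's real helper, which formats an error
    # message so SillyTavern renders it inline in the chat.
    return f'[{level}] {msg}'


def stream_response(ws, generator, msg_to_backend):
    # The commit's pattern: everything after the backend request is built
    # runs inside try/except, so a mid-stream failure still reaches the client.
    generated_text = ''
    message_num = 0
    try:
        response = generator(msg_to_backend)
        # ... iterate over the streaming response here, ws.send()-ing each
        # chunk, accumulating generated_text and incrementing message_num ...
    except Exception:  # the commit itself uses a bare `except:`
        # Append an inline error to whatever was generated so far, print the
        # traceback server-side, and push the result as one final text chunk.
        generated_text += '\n\n' + format_sillytavern_err(
            'encountered error while streaming', 'error')
        traceback.print_exc()
        ws.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': generated_text,
        }))
    # The stream is always closed, error or not.
    ws.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num,
    }))
```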

View File

@@ -19,6 +19,8 @@ app = FastAPI()
 served_model = None

+# TODO: figure out ROPE scaling
+# TODO: make sure error messages are returned in the response

 @app.get("/model")
 async def generate(request: Request) -> Response: