handle error while streaming
This commit is contained in:
parent
cb99c3490e
commit
2678102153
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from flask import request
|
||||
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ... import opts
|
||||
|
@ -13,6 +15,7 @@ from ...stream import sock
|
|||
|
||||
|
||||
# TODO: have workers process streaming requests
|
||||
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
|
||||
|
||||
@sock.route('/api/v1/stream')
|
||||
def stream(ws):
|
||||
|
@ -55,47 +58,59 @@ def stream(ws):
|
|||
'prompt': input_prompt,
|
||||
'stream': True,
|
||||
}
|
||||
response = generator(msg_to_backend)
|
||||
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
response = generator(msg_to_backend)
|
||||
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
|
||||
partial_response = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string
|
||||
json_obj = json.loads(json_str)
|
||||
try:
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
message_num += 1
|
||||
|
||||
generated_text = generated_text + new
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
response.close()
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
|
||||
except:
|
||||
response_status_code = 0
|
||||
|
||||
partial_response = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string
|
||||
json_obj = json.loads(json_str)
|
||||
try:
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
message_num += 1
|
||||
|
||||
generated_text = generated_text + new
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
response.close()
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
|
||||
generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error')
|
||||
generated_tokens = tokenize(generated_text)
|
||||
traceback.print_exc()
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': generated_text
|
||||
}))
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
|
||||
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
|
|
|
@ -19,6 +19,8 @@ app = FastAPI()
|
|||
|
||||
served_model = None
|
||||
|
||||
# TODO: figure out ROPE scaling
|
||||
# TODO: make sure error messages are returned in the response
|
||||
|
||||
@app.get("/model")
|
||||
async def generate(request: Request) -> Response:
|
||||
|
|
Reference in New Issue