"""WebSocket streaming endpoint (/api/v1/stream) that relays vLLM backend output to clients."""
import json
import time
import traceback

from flask import request

from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock

# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)

@sock.route('/api/v1/stream')
def stream(ws):
    """Relay streaming text generation from the vLLM backend over a websocket.

    Protocol (as implemented here): the client sends a JSON message that must
    contain a 'prompt' key. Each generated fragment is sent back as
    ``{'event': 'text_stream', 'message_num': N, 'text': ...}`` and the end of
    a response is signalled with ``{'event': 'stream_end', 'message_num': N}``.

    The backend delimits its JSON payloads with a NUL byte (b'\\x00') and
    returns the *full* text generated so far on each payload; this handler
    strips the already-sent prefix and forwards only the new fragment.
    """
    if not opts.enable_streaming:
        # TODO: return a formatted ST error message
        return 'disabled', 401

    message_num = 0
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if not request_valid_json or not request_json_body.get('prompt'):
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': 'Invalid JSON'
            }))
            message_num += 1
        else:
            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            handler = OobaRequestHandler(request, request_json_body)
            generated_text = ''
            input_prompt = None
            response_status_code = 0
            start_time = time.time()
            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': invalid_response
                }))
            else:
                input_prompt = request_json_body['prompt']
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': input_prompt,
                    'stream': True,
                }
                try:
                    response = generator(msg_to_backend)

                    # Be extra careful when getting attributes from the response
                    # object. getattr() replaces a bare `except:` that also
                    # swallowed KeyboardInterrupt/SystemExit.
                    response_status_code = getattr(response, 'status_code', 0)

                    # Accumulate bytes until we hit the NUL delimiter, then
                    # decode and parse one backend JSON message.
                    partial_response = b''
                    for chunk in response.iter_content(chunk_size=1):
                        partial_response += chunk
                        if partial_response.endswith(b'\x00'):
                            # Remove the null character and decode the byte string to a string.
                            json_str = partial_response[:-1].decode()
                            json_obj = json.loads(json_str)
                            try:
                                # The backend returns the whole text so far;
                                # split off the prompt + already-sent text to
                                # get only the newly generated fragment.
                                new = json_obj['text'][0].split(input_prompt + generated_text)[1]
                            except IndexError:
                                # Payload did not extend prompt + generated
                                # text (cause unclear — see original '????'
                                # note); skip it rather than crash the stream.
                                continue

                            ws.send(json.dumps({
                                'event': 'text_stream',
                                'message_num': message_num,
                                'text': new
                            }))
                            message_num += 1
                            generated_text = generated_text + new
                            partial_response = b''  # Reset the partial response

                        # If there is no more data, break the loop
                        if not chunk:
                            break

                    response.close()

                    elapsed_time = time.time() - start_time
                    generated_tokens = tokenize(generated_text)
                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
                except Exception:
                    # Narrowed from a bare `except:` so KeyboardInterrupt and
                    # SystemExit still propagate. Send the client a formatted
                    # error message and log the (partial) generation.
                    generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error')
                    generated_tokens = tokenize(generated_text)
                    traceback.print_exc()
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': generated_text
                    }))
                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)

            # Signal the end of this message's stream (sent for both valid and
            # invalid requests once the prompt JSON itself parsed).
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))
|