local-llm-server/llm_server/routes/v1/generate_stream.py

164 lines
7.1 KiB
Python
Raw Normal View History

2023-08-29 17:56:12 -06:00
import json
import threading
2023-08-29 17:56:12 -06:00
import time
2023-09-24 13:27:27 -06:00
import traceback
2023-08-29 17:56:12 -06:00
from flask import request
2023-09-24 13:27:27 -06:00
from ..helpers.client import format_sillytavern_err
from ..helpers.http import require_api_key, validate_json
2023-09-23 17:57:23 -06:00
from ..ooba_request_handler import OobaRequestHandler
2023-08-29 17:56:12 -06:00
from ... import opts
2023-09-25 12:38:02 -06:00
from ...database.database import log_prompt
2023-09-23 17:57:23 -06:00
from ...llm.generator import generator
from ...llm.vllm import tokenize
2023-08-29 17:56:12 -06:00
from ...stream import sock
2023-08-30 18:53:26 -06:00
# TODO: have workers process streaming requests
2023-09-24 13:27:27 -06:00
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
2023-08-29 17:56:12 -06:00
2023-09-23 22:30:59 -06:00
def _send_event(ws, event, message_num, text=None):
    """Send one ooba-style streaming event to the client as a JSON string.

    :param ws: the client websocket.
    :param event: event name, e.g. ``'text_stream'`` or ``'stream_end'``.
    :param message_num: sequence number attached to the event.
    :param text: event text; the ``text`` key is omitted when this is None.
    """
    payload = {'event': event, 'message_num': message_num}
    if text is not None:
        payload['text'] = text
    ws.send(json.dumps(payload))


def _log_in_thread(handler, prompt, response_text, elapsed_time, headers, status_code, url, is_error=False):
    """Record one request with ``log_prompt`` on a worker thread and wait for it.

    The thread is joined before returning, so the call is still synchronous.
    Running the work on its own thread means an exception raised by
    ``tokenize`` or ``log_prompt`` is reported by the ``threading`` module
    instead of propagating into the websocket handler.

    :param is_error: when true, the row is written with ``is_error=True`` and no
        response tokens; otherwise ``response_text`` is tokenized and passed as
        ``response_tokens``.
    """

    def task():
        if is_error:
            extra = {'is_error': True}
        else:
            extra = {'response_tokens': tokenize(response_text)}
        log_prompt(handler.client_ip, handler.token, prompt, response_text, elapsed_time,
                   handler.parameters, headers, status_code, url, **extra)

    # TODO: use async/await instead of threads
    thread = threading.Thread(target=task)
    thread.start()
    thread.join()


def _iter_backend_payloads(response):
    """Yield every JSON object from the backend's NUL-delimited streaming body.

    The body is read one byte at a time, so each message is forwarded as soon
    as its terminating ``\\x00`` arrives. The buffer is a ``bytearray`` because
    ``bytes += chunk`` copies the entire buffer on every byte. That is
    quadratic in message size, and messages can be large: the caller expects
    each one to repeat the prompt plus all text generated so far.
    The buffer is emptied after every complete delimiter, whether or not the
    caller ends up using the decoded messages.
    """
    buffer = bytearray()
    for chunk in response.iter_content(chunk_size=1):
        # If there is no more data, stop reading.
        if not chunk:
            break
        buffer += chunk
        if not buffer.endswith(b'\x00'):
            continue
        for piece in buffer.split(b'\x00'):
            if piece:
                yield json.loads(piece.decode())
        buffer.clear()


@sock.route('/api/v1/stream')
def stream(ws):
    """Serve the ooba-compatible streaming endpoint over a websocket.

    Each client message must be a JSON object with a non-empty ``prompt``.
    A malformed message gets an ``'Invalid JSON'`` ``text_stream`` event and
    the handler waits for another message. A valid message is checked with
    ``handler.validate_request`` and forwarded to the backend with
    ``stream: True``. Each newly generated piece of text is sent back as a
    ``text_stream`` event, followed by a ``stream_end`` event, and then the
    socket is closed. Each handled request is recorded with ``log_prompt``
    exactly once.
    """
    if not opts.enable_streaming:
        # TODO: return a formatted ST error message
        return 'disabled', 401
    auth_failure = require_api_key()
    if auth_failure:
        return auth_failure

    # Snapshot request metadata up front; logging runs on a separate thread.
    r_headers = dict(request.headers)
    r_url = request.url

    message_num = 0
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if (not request_valid_json
                or not isinstance(request_json_body, dict)
                or not request_json_body.get('prompt')):
            _send_event(ws, 'text_stream', message_num, 'Invalid JSON')
            message_num += 1
            continue

        if opts.mode != 'vllm':
            # TODO: implement other backends
            raise NotImplementedError

        handler = OobaRequestHandler(request, request_json_body)
        generated_text = ''
        input_prompt = request_json_body['prompt']
        response_status_code = 0
        start_time = time.time()

        request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
        if not request_valid:
            err_msg = invalid_response[0].json['results'][0]['text']
            _send_event(ws, 'text_stream', message_num, err_msg)
            _send_event(ws, 'stream_end', message_num + 1)
            ws.close()  # this is important if we encountered an error and exited early.
            _log_in_thread(handler, input_prompt, err_msg, None, r_headers,
                           response_status_code, r_url, is_error=True)
            continue  # the socket is closed, so the loop ends here

        msg_to_backend = {
            **handler.parameters,
            'prompt': input_prompt,
            'stream': True,
        }
        try:
            response = generator(msg_to_backend)
            # Be extra careful when getting attributes from the response object
            try:
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0

            try:
                for payload in _iter_backend_payloads(response):
                    try:
                        full_text = payload['text'][0]
                    except IndexError:
                        continue
                    # The backend text is expected to start with the prompt plus everything
                    # generated so far; the new text is whatever follows that prefix.
                    # partition() keeps everything after the first match, whereas
                    # split(prefix)[1] would lose text if the prefix recurred in the output.
                    _, found, new = full_text.partition(input_prompt + generated_text)
                    if not found:
                        # ???? (message does not contain the expected prefix; skip it)
                        continue
                    generated_text += new
                    try:
                        _send_event(ws, 'text_stream', message_num, new)
                    except Exception:
                        # The client disconnected mid-stream: record what we have and stop.
                        _log_in_thread(handler, input_prompt, generated_text, time.time() - start_time,
                                       r_headers, response_status_code, r_url)
                        return
                    message_num += 1
            finally:
                # Always release the backend connection, including on client disconnect
                # or error, so the backend is not left streaming to nobody.
                close = getattr(response, 'close', None)
                if close is not None:
                    close()

            elapsed_time = time.time() - start_time
            _log_in_thread(handler, input_prompt, generated_text, elapsed_time,
                           r_headers, response_status_code, r_url)
        except Exception:
            traceback.print_exc()
            error_text = '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
            generated_text += error_text
            # text_stream events carry only newly generated text, so send just the
            # error suffix; resending generated_text would duplicate streamed output.
            try:
                _send_event(ws, 'text_stream', message_num, error_text)
                message_num += 1
            except Exception:
                # Best effort: the client may already be gone. Still log below.
                traceback.print_exc()
            _log_in_thread(handler, input_prompt, generated_text, None,
                           r_headers, response_status_code, r_url)

        try:
            _send_event(ws, 'stream_end', message_num)
            ws.close()  # this is important if we encountered an error and exited early.
        except Exception:
            # The client is already gone. The request was logged above, so it is
            # deliberately not logged a second time here.
            pass