2023-08-29 17:56:12 -06:00
|
|
|
import json
|
2023-09-25 12:30:40 -06:00
|
|
|
import threading
|
2023-08-29 17:56:12 -06:00
|
|
|
import time
|
2023-09-24 13:27:27 -06:00
|
|
|
import traceback
|
2023-09-27 21:15:54 -06:00
|
|
|
from typing import Union
|
2023-08-29 17:56:12 -06:00
|
|
|
|
|
|
|
from flask import request
|
|
|
|
|
2023-09-25 12:30:40 -06:00
|
|
|
from ..helpers.http import require_api_key, validate_json
|
2023-09-23 17:57:23 -06:00
|
|
|
from ..ooba_request_handler import OobaRequestHandler
|
2023-09-28 08:47:39 -06:00
|
|
|
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
2023-08-29 17:56:12 -06:00
|
|
|
from ... import opts
|
2023-09-25 12:38:02 -06:00
|
|
|
from ...database.database import log_prompt
|
2023-09-23 17:57:23 -06:00
|
|
|
from ...llm.generator import generator
|
|
|
|
from ...llm.vllm import tokenize
|
2023-08-29 17:56:12 -06:00
|
|
|
from ...stream import sock
|
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
|
|
|
|
# TODO: have workers process streaming requests
|
2023-09-24 13:27:27 -06:00
|
|
|
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
|
2023-08-29 17:56:12 -06:00
|
|
|
|
2023-09-23 22:30:59 -06:00
|
|
|
def _new_text_since(full_text, prefix):
    """Return the part of ``full_text`` that comes after ``prefix``.

    ``stream`` assumes every backend message carries the cumulative text
    (prompt plus everything generated so far), so the new delta is whatever
    follows the text already seen.  Returns ``None`` when ``full_text`` does
    not start with ``prefix``, i.e. the message cannot be aligned with what
    has already been sent.

    Slicing is used rather than ``full_text.split(prefix)[1]``: ``split`` cuts
    at *every* occurrence of ``prefix``, so output that repeats the prompt (or
    its own earlier text) would have the tail of its delta silently dropped.
    """
    if not full_text.startswith(prefix):
        return None
    return full_text[len(prefix):]


@sock.route('/api/v1/stream')
def stream(ws):
    """Stream completions for prompts received over the websocket ``ws``.

    Each received message must be JSON with a non-empty ``prompt`` (plus the
    parameters ``OobaRequestHandler`` reads).  Generated text is sent back as
    ``text_stream`` events followed by a ``stream_end`` event, and every
    request is recorded with ``log_prompt``.  Messages are processed until the
    client disconnects.  Only ``opts.mode == 'vllm'`` is supported.
    """

    def send_err_and_quit(quitting_err_msg):
        """Send ``quitting_err_msg`` as a one-message stream, close ``ws`` and log it as an error.

        Must only be called once ``handler`` and ``input_prompt`` exist, since
        logging goes through ``log_in_bg``.
        """
        ws.send(json.dumps({
            'event': 'text_stream',
            'message_num': 0,
            'text': quitting_err_msg
        }))
        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': 1
        }))
        ws.close()
        log_in_bg(quitting_err_msg, is_error=True)

    def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
        """Tokenize ``generated_text_bg`` and record the request via ``log_prompt``.

        Reads ``handler``, ``input_prompt``, ``r_headers`` and ``r_url`` from
        the enclosing scope.  Despite the name this is synchronous: the worker
        thread is joined immediately.
        """

        def background_task_exception():
            generated_tokens = tokenize(generated_text_bg)
            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)

        # TODO: use async/await instead of threads
        thread = threading.Thread(target=background_task_exception)
        thread.start()
        thread.join()

    def close_backend(backend_response):
        """Best-effort close of the backend response so its connection is not left open.

        Accepts anything ``generator`` returned (including ``None``/falsy values);
        objects without a ``close`` method are ignored.
        """
        close = getattr(backend_response, 'close', None)
        if close is None:
            return
        try:
            close()
        except Exception:
            traceback.print_exc()

    def close_ws_quietly():
        """Close ``ws``, ignoring errors (closing an already-disconnected socket can raise)."""
        try:
            ws.close()
        except Exception:
            pass

    if not opts.enable_streaming:
        # NOTE(review): if ``sock`` is flask-sock, a websocket route's return
        # value is discarded, so the client never sees this message/status.
        return 'Streaming is disabled', 401

    # Captured up front; the logging helpers above read these.
    r_headers = dict(request.headers)
    r_url = request.url
    message_num = 0
    while ws.connected:
        message = ws.receive()
        request_valid_json, request_json_body = validate_json(message)
        if not request_valid_json or not request_json_body.get('prompt'):
            return 'Invalid JSON', 400  # See the NOTE(review) above about discarded return values.

        if opts.mode != 'vllm':
            # TODO: implement other backends
            raise NotImplementedError

        auth_failure = require_api_key(request_json_body)
        if auth_failure:
            return auth_failure

        handler = OobaRequestHandler(request, request_json_body)
        generated_text = ''
        input_prompt = request_json_body['prompt']
        response_status_code = 0
        start_time = time.time()

        err_msg = None
        if handler.is_client_ratelimited():
            r, _ = handler.handle_ratelimited(do_log=False)
            err_msg = r.json['results'][0]['text']
        else:
            request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
            if not request_valid:
                err_msg = invalid_response[0].json['results'][0]['text']
        if err_msg:
            send_err_and_quit(err_msg)
            return

        llm_request = {
            **handler.parameters,
            'prompt': input_prompt,
            'stream': True,
        }

        # Add a dummy event to the queue and wait for it to reach a worker
        event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
        if not event:
            r, _ = handler.handle_ratelimited()
            err_msg = r.json['results'][0]['text']
            send_err_and_quit(err_msg)
            return

        # Bound before ``try`` so the cleanup paths can safely reference it even
        # if ``generator`` itself raises.
        response = None
        try:
            response = generator(llm_request)
            if not response:
                error_msg = 'Failed to reach backend while streaming.'
                print('Streaming failed:', error_msg)
                msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': msg
                }))
            else:
                # Be extra careful when getting attributes from the response object
                try:
                    response_status_code = response.status_code
                except Exception:
                    response_status_code = 0

                partial_response = b''

                # Read one byte at a time so a chunk boundary always lands on the
                # b'\x00' message delimiter checked below.
                for chunk in response.iter_content(chunk_size=1):
                    partial_response += chunk
                    if partial_response.endswith(b'\x00'):
                        json_strs = partial_response.split(b'\x00')
                        for json_str in json_strs:
                            if not json_str:
                                continue
                            try:
                                json_obj = json.loads(json_str.decode())
                                new = _new_text_since(json_obj['text'][0], input_prompt + generated_text)
                            except IndexError:
                                new = None
                            if new is None:
                                # The message does not line up with the text already
                                # sent (or has no text); skip it.
                                continue
                            generated_text = generated_text + new
                            try:
                                ws.send(json.dumps({
                                    'event': 'text_stream',
                                    'message_num': message_num,
                                    'text': new
                                }))
                            except Exception:
                                # The client has closed the stream: stop reading from
                                # the backend and log what was generated so far.
                                close_backend(response)
                                close_ws_quietly()
                                elapsed_time = time.time() - start_time
                                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
                                return

                            message_num += 1
                            partial_response = b''  # Reset the partial response

                    # If there is no more data, break the loop
                    if not chunk:
                        break

            elapsed_time = time.time() - start_time
            log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
        except Exception:
            traceback.print_exc()
            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
            close_backend(response)
            try:
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': generated_text
                }))
            except Exception:
                pass  # The client is already gone; still log the failure below.
            close_ws_quietly()
            log_in_bg(generated_text, is_error=True, status_code=response_status_code)
            return
        finally:
            # The worker incremented it, we'll decrement it.
            decrement_ip_count(handler.client_ip, 'processing_ips')
            decr_active_workers()
            # Release the backend connection on every path (closing twice is harmless).
            close_backend(response)

        try:
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))
        except Exception:
            # The client closed the stream. The request was already logged by
            # ``log_in_bg`` above, so it is not logged a second time here.
            pass
    # Reached once ``ws.connected`` turns false; early ``return`` paths skip this.
    close_ws_quietly()
|