This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/routes/v1/generate_stream.py

257 lines
11 KiB
Python
Raw Normal View History

2023-08-29 17:56:12 -06:00
import json
import time
2023-09-24 13:27:27 -06:00
import traceback
2023-08-29 17:56:12 -06:00
from flask import request
2023-10-01 00:20:00 -06:00
from . import bp
from ..helpers.http import require_api_key, validate_json
2023-09-23 17:57:23 -06:00
from ..ooba_request_handler import OobaRequestHandler
2023-09-28 08:47:39 -06:00
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
2023-08-29 17:56:12 -06:00
from ... import opts
2023-09-25 12:38:02 -06:00
from ...database.database import log_prompt
2023-09-23 17:57:23 -06:00
from ...llm.generator import generator
2023-09-29 00:09:44 -06:00
from ...sock import sock
2023-08-29 17:56:12 -06:00
2023-08-30 18:53:26 -06:00
2023-10-01 00:20:00 -06:00
# Stacking multiple @sock.route() decorators on one function raises a TypeError
# on the /v1/stream endpoint, so each websocket route gets its own function below.
2023-08-29 17:56:12 -06:00
2023-10-01 01:13:13 -06:00
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
    """Answer plain HTTP requests to the stream paths with a 400.

    The streaming protocol itself is served by the websocket handlers that are
    registered on these same paths; this view only tells HTTP clients so.
    """
    status_code = 400
    return 'This is a websocket endpoint.', status_code
2023-10-01 01:13:13 -06:00
@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
    """Websocket entry point for ``/v1/stream`` (no model named in the URL)."""
    selected_model = None
    do_stream(ws, selected_model)
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
    """Websocket entry point for ``/<model_name>/v1/stream``.

    Forwards the model name captured from the URL to the shared handler.
    """
    do_stream(ws, model_name=model_name)
def _text_event(message_num, text):
    """Serialize a ``text_stream`` event carrying ``text`` as message ``message_num``."""
    return json.dumps({
        'event': 'text_stream',
        'message_num': message_num,
        'text': text,
    })


def _end_event(message_num):
    """Serialize the ``stream_end`` event with sequence number ``message_num``."""
    return json.dumps({
        'event': 'stream_end',
        'message_num': message_num,
    })


def _send_text_and_end(ws, text):
    """Send ``text`` to the client as a complete one-message stream (message 0, then end as 1)."""
    ws.send(_text_event(0, text))
    ws.send(_end_event(1))


def _new_text_since(full_text, prefix):
    """Return the part of ``full_text`` that comes after ``prefix``.

    Returns None when ``full_text`` does not start with ``prefix``, i.e. the
    update cannot be lined up with what has already been streamed.
    """
    # A prefix check + slice is used instead of ``full_text.split(prefix)[1]``,
    # which truncated the delta whenever ``prefix`` re-occurred inside it.
    if not full_text.startswith(prefix):
        return None
    return full_text[len(prefix):]


def _stream_one_prompt(ws, model_name, request_json_body, r_headers, r_url):
    """Serve one prompt received on ``ws`` and stream the generation back.

    :returns: True when the socket may be reused for another prompt; False when
        the caller should stop (an error was reported or the client went away).
    """
    handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
    input_prompt = request_json_body['prompt']
    generated_text = ''
    message_num = 0  # sequence number of the next text_stream event for this prompt
    response_status_code = 0
    response = None
    start_time = time.time()

    def log(response_text, gen_time, is_error=False):
        # Reads response_status_code at call time, so it reflects the backend status once known.
        log_prompt(ip=handler.client_ip,
                   token=handler.token,
                   prompt=input_prompt,
                   response=response_text,
                   gen_time=gen_time,
                   parameters=handler.parameters,
                   headers=r_headers,
                   backend_response_code=response_status_code,
                   request_url=r_url,
                   backend_url=handler.backend_url,
                   response_tokens=None,
                   is_error=is_error)

    def send_err_and_quit(quitting_err_msg):
        _send_text_and_end(ws, quitting_err_msg)
        log(quitting_err_msg, None, is_error=True)

    err_msg = None
    if handler.is_client_ratelimited():
        r, _ = handler.handle_ratelimited(do_log=False)
        err_msg = r.json['results'][0]['text']
    else:
        request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
        if not request_valid:
            err_msg = invalid_response[0].json['results'][0]['text']
    if err_msg:
        send_err_and_quit(err_msg)
        return False

    llm_request = {
        **handler.parameters,
        'prompt': input_prompt,
        'stream': True,
    }

    # Add a dummy event to the queue and wait for it to reach a worker.
    event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
    if not event:
        r, _ = handler.handle_ratelimited()
        send_err_and_quit(r.json['results'][0]['text'])
        return False

    # Wait for a worker to get our request and discard it.
    _, _, _ = event.wait()

    try:
        response = generator(llm_request, handler.backend_url)
        if not response:
            error_msg = 'Failed to reach backend while streaming.'
            print('Streaming failed:', error_msg)
            msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
            ws.send(_text_event(message_num, msg))
        else:
            # Be extra careful when getting attributes from the response object.
            try:
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0

            partial_response = b''
            for chunk in response.iter_content(chunk_size=1):
                # If there is no more data, stop reading.
                if not chunk:
                    break
                partial_response += chunk
                # Backend updates are NUL-terminated JSON documents; wait for a full one.
                if not partial_response.endswith(b'\x00'):
                    continue
                json_strs = partial_response.split(b'\x00')
                partial_response = b''
                for json_str in json_strs:
                    if not json_str:
                        continue
                    texts = json.loads(json_str.decode())['text']
                    if not texts:
                        continue
                    # The code treats each update as the full text so far
                    # (prompt + everything generated), so only the suffix is new.
                    new = _new_text_since(texts[0], input_prompt + generated_text)
                    if new is None:
                        continue
                    generated_text += new
                    try:
                        ws.send(_text_event(message_num, new))
                    except Exception:
                        # The client has closed the stream; the finally below releases the backend.
                        log(generated_text, time.time() - start_time)
                        return False
                    message_num += 1

        log(generated_text, time.time() - start_time, is_error=not response)
    except Exception:
        traceback.print_exc()
        generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
        try:
            ws.send(_text_event(message_num, generated_text))
        except Exception:
            pass  # The client is already gone; still record the failure below.
        log(generated_text, None, is_error=True)
        return False
    finally:
        # Release the backend stream (this also stops reading it when we bail out early).
        # NOTE(review): presumably a requests.Response; close() is looked up defensively.
        close = getattr(response, 'close', None)
        if callable(close):
            try:
                close()
            except Exception:
                pass
        # The worker incremented it, we'll decrement it.
        decrement_ip_count(handler.client_ip, 'processing_ips')
        decr_active_workers(handler.selected_model, handler.backend_url)

    try:
        ws.send(_end_event(message_num))
    except Exception:
        # The client closed the stream; the generation was already logged above.
        return False
    return True


def do_stream(ws, model_name):
    """Serve text generation over the websocket ``ws``.

    Each message received from the client must be a JSON object with a non-empty
    ``prompt``. The generation is relayed as ``text_stream`` events followed by a
    ``stream_end`` event, and the socket is closed on exit.

    :param ws: the websocket connection supplied by ``sock.route``.
    :param model_name: model selected via the URL, or None for the default route.
    :returns: the legacy ``(message, status)`` tuples are kept only for
        compatibility; the route wrappers in this module discard the return value,
        so errors are also sent to the client over the socket.
    """
    try:
        if not opts.enable_streaming:
            _send_text_and_end(ws, 'Streaming is disabled')
            return 'Streaming is disabled', 500

        r_headers = dict(request.headers)
        r_url = request.url

        while ws.connected:
            message = ws.receive()
            request_valid_json, request_json_body = validate_json(message)
            if not request_valid_json or not request_json_body.get('prompt'):
                _send_text_and_end(ws, 'Invalid JSON')
                return 'Invalid JSON', 400

            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            auth_failure = require_api_key(request_json_body)
            if auth_failure:
                # NOTE(review): this response is not relayed to the websocket client.
                return auth_failure

            if not _stream_one_prompt(ws, model_name, request_json_body, r_headers, r_url):
                return
    finally:
        try:
            # Must close the connection or greenlets will complain.
            ws.close()
        except Exception:
            pass