begin streaming rewrite

This commit is contained in:
Cyberes 2023-10-16 00:18:05 -06:00
parent 24aab3cd93
commit 151b3e4769
5 changed files with 137 additions and 141 deletions

View File

@@ -1,8 +1,10 @@
import json
import pickle
import time
import traceback
from flask import Response, jsonify, request
from redis import Redis
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
@@ -11,7 +13,6 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -64,24 +65,18 @@ def openai_chat_completions(model_name=None):
# Prevent issues on the backend.
return 'Invalid prompt', 400
event_id = None
# Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
event = None
if not handler.is_client_ratelimited():
start_time = time.time()
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
@@ -97,27 +92,6 @@ def openai_chat_completions(model_name=None):
)
return handler.handle_ratelimited()
# Once the worker receives our streaming request, it will tell us we are ready
# to begin inference.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if time.time() - start_time >= opts.backend_generate_request_timeout:
raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
# This shouldn't happen because the best model should be auto-selected.
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)
# Double check the model is still online
if not handler.check_online():
return return_invalid_model_err(handler.request_json_body['model'])
try:
r_headers = dict(request.headers)
r_url = request.url
@@ -125,68 +99,62 @@ def openai_chat_completions(model_name=None):
oai_string = generate_oai_string(30)
def generate():
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
try:
response = generator(msg_to_backend, handler.backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
redis.publish(event_id, 'chunk') # Keepalive
except IndexError:
# ????
continue
data = {
while True:
stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
if not stream_data:
print("No message received in 30 seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for r_timestamp, item in stream_data[0][1]:
timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data'])
if data['error']:
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
"content": data['new']
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
except GeneratorExit:
yield 'data: [DONE]\n\n'
except:
# AttributeError: 'bool' object has no attribute 'iter_content'
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except (Exception, GeneratorExit):
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
# After completing inference, we need to tell the worker we are finished.
if event_id: # may be None if ratelimited.
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
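For reference, a minimal client-side sketch of consuming the rewritten endpoint's SSE output. The host, port, path, and model name are assumptions for illustration; the `data: ...` / `data: [DONE]` framing matches the generator above.

import json
import requests

# Hypothetical local deployment; adjust host, port, and path as needed.
URL = 'http://localhost:5000/api/openai/v1/chat/completions'

payload = {
    'model': 'example-model',  # illustrative; the proxy maps this to a backend
    'stream': True,
    'messages': [{'role': 'user', 'content': 'Hello!'}],
}

with requests.post(URL, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b'data: '):
            continue
        body = line[len(b'data: '):]
        if body == b'[DONE]':
            break
        chunk = json.loads(body)
        # Each chunk follows the chat.completion.chunk structure built in generate().
        print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)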

View File

@@ -27,7 +27,7 @@ class RedisPriorityQueue:
self.name = name
self.redis = RedisCustom(name, db=db)
def put(self, item, priority, selected_model):
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
assert item is not None
assert priority is not None
assert selected_model is not None
@@ -43,7 +43,7 @@ class RedisPriorityQueue:
return None # reject the request
timestamp = time.time()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event
@@ -106,6 +106,7 @@ class DataEvent:
self.redis.publish(self.event_id, pickle.dumps(data))
def wait(self):
# TODO: implement timeout
for item in self.pubsub.listen():
if item['type'] == 'message':
return pickle.loads(item['data'])
@@ -151,9 +152,9 @@ class PriorityQueue:
count += queue.get_queued_ip_count(client_ip)
return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model)
return queue.put(item, priority, selected_model, do_stream)
def activity(self):
lines = []
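A standalone sketch of the sorted-set pattern this queue is built on, showing how the new do_stream flag travels inside the serialized member. It assumes a local Redis instance; the pop side uses ZPOPMIN here purely for illustration since the queue's get() is not part of this hunk, and the event_id is a placeholder for the DataEvent the real put() creates.

import json
import time
from redis import Redis

r = Redis(db=15)  # throwaway database for the example
r.delete('example_queue')

def put(item, priority: int, selected_model: str, do_stream: bool = False):
    event_id = 'example-event-id'  # the real code allocates a DataEvent here
    member = json.dumps((item, event_id, selected_model, time.time(), do_stream))
    # Score is the negated priority so a higher priority sorts first.
    r.zadd('example_queue', {member: -priority})

def get():
    # Illustrative consumer: pop the lowest score, i.e. the highest priority.
    popped = r.zpopmin('example_queue', 1)
    if not popped:
        return None
    member, _score = popped[0]
    return json.loads(member)

put(({'prompt': 'hi'}, '127.0.0.1', None, {}), priority=1, selected_model='an-example-model', do_stream=True)
item, event_id, selected_model, timestamp, do_stream = get()
print(do_stream)  # True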

View File

@@ -0,0 +1,32 @@
import time
from redis import Redis
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
# NOT NEEDED
def cleaner():
r = Redis(db=8)
stream_info = {}
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)
# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)
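The cleaner judges staleness purely from XLEN staying constant for five minutes. A compressed, self-contained sketch of that bookkeeping on a single stream, with the timings shortened and the key name invented for illustration (local Redis assumed):

import time
from redis import Redis

r = Redis(db=8)
stream = 'stream:example-worker'
r.xadd(stream, {'data': b'chunk'})

stream_info = {}
IDLE_SECONDS = 2  # stands in for the 300-second threshold above

for _ in range(4):
    size = r.xlen(stream)
    if stream not in stream_info or size != stream_info[stream]['size']:
        # New stream, or it grew since last check: remember size and time.
        stream_info[stream] = {'size': size, 'time': time.time()}
    elif time.time() - stream_info[stream]['time'] >= IDLE_SECONDS:
        # Unchanged for too long: treat the stream as abandoned and drop it.
        r.delete(stream)
        del stream_info[stream]
        print(f"Deleted idle stream '{stream}'")
        break
    time.sleep(1)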

View File

@@ -1,90 +1,88 @@
import queue
import json
import pickle
import threading
import time
import traceback
from uuid import uuid4
from redis.client import PubSub
from redis import Redis
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.custom_redis import RedisCustom
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
class ListenerThread(threading.Thread):
def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
threading.Thread.__init__(self)
self.pubsub = pubsub
self.listener_queue = listener_queue
self.stop_event = stop_event
STREAM_NAME_PREFIX = 'stream'
def run(self):
while not self.stop_event.is_set():
message = self.pubsub.get_message()
if message:
self.listener_queue.put(message)
time.sleep(0.1)
def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(stream_name)  # be extra sure (stream_name was already prefixed above)
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
except Exception as e:
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
traceback.print_exc()
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
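inference_do_stream() splits the backend's NUL-delimited JSON frames and keeps only the newly generated text by stripping the prompt-plus-previous-text prefix. Below is a self-contained sketch of that parsing step on canned bytes; the frame layout (full text so far in json_obj['text'][0]) is inferred from the code above, and resetting partial_response is a simplification.

import json

prompt = 'Hello'
generated_text = ''
# Two NUL-terminated frames, each carrying the full text generated so far.
raw = b'{"text": ["Hello wor"]}\x00{"text": ["Hello world"]}\x00'

partial_response = b''
for i in range(len(raw)):  # emulates iter_content(chunk_size=1)
    partial_response += raw[i:i + 1]
    if partial_response.endswith(b'\x00'):
        for json_str in partial_response.split(b'\x00'):
            if json_str:
                json_obj = json.loads(json_str.decode())
                # The frame holds prompt + everything generated so far,
                # so split on that prefix to isolate the new delta.
                new = json_obj['text'][0].split(prompt + generated_text)[1]
                generated_text += new
                print(repr(new))  # ' wor', then 'ld'
        partial_response = b''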
def worker(backend_url):
status_redis = RedisCustom('worker_status')
worker_id = uuid4()
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
backend_info = cluster_config.get_backend(backend_url)
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
stop_event = threading.Event()
q = queue.Queue()
listener = ListenerThread(pubsub, q, stop_event)
listener.start()
if not backend_info['online']:
redis.publish(event_id, 'offline')
# TODO: communicate to caller
# redis.publish(event_id, 'offline')
return
if not selected_model:
selected_model = backend_info['model']
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), (backend_url, client_ip))
status_redis.setp(str(worker_id), ('generating', client_ip))
try:
if not request_json_body:
# This was a dummy request from the streaming handlers.
# The worker will let the handler do the streaming instead
# of the worker. The worker will block until the handler
# is finished. Since a lot of ratelimiting and stats are
# based off the number of active workers, we must keep
# the generation based off the workers.
start_time = time.time()
redis.publish(event_id, 'begin')
while True:
status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
try:
item = q.get(timeout=30)
except queue.Empty:
print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
break
if time.time() - start_time >= opts.backend_generate_request_timeout:
status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
break
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
status_redis.setp(str(worker_id), ('streaming completed', client_ip))
break
if do_stream:
event = DataEvent(event_id)
event.set(get_stream_name(worker_id))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url)
else:
status_redis.setp(str(worker_id), ('generating', client_ip))
# Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
@@ -92,8 +90,6 @@ def worker(backend_url):
except:
traceback.print_exc()
finally:
stop_event.set() # make sure to stop the listener thread
listener.join()
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
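The worker/handler handoff hinges on DataEvent: the worker publishes the stream name over pub/sub (event.set(get_stream_name(worker_id)) above) and the handler's event.wait() blocks until that message arrives. A minimal standalone sketch of the same pickle-over-pub/sub pattern, assuming a local Redis instance and an invented channel name:

import pickle
import threading
import time
from redis import Redis

r = Redis()
event_id = 'example-event-id'

def worker_side():
    time.sleep(0.5)  # pretend the request was just pulled off the queue
    # Equivalent of DataEvent.set(): pickle the payload and publish it.
    r.publish(event_id, pickle.dumps('stream:example-worker'))

def handler_side():
    pubsub = r.pubsub()
    pubsub.subscribe(event_id)
    # Equivalent of DataEvent.wait(): block until the worker publishes.
    for item in pubsub.listen():
        if item['type'] == 'message':
            return pickle.loads(item['data'])

threading.Thread(target=worker_side, daemon=True).start()
print(handler_side())  # -> 'stream:example-worker'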

View File

@@ -2,7 +2,6 @@ import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers