begin streaming rewrite

This commit is contained in:
Cyberes 2023-10-16 00:18:05 -06:00
parent 24aab3cd93
commit 151b3e4769
5 changed files with 137 additions and 141 deletions

View File

@ -1,8 +1,10 @@
import json import json
import pickle
import time import time
import traceback import traceback
from flask import Response, jsonify, request from flask import Response, jsonify, request
from redis import Redis
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp from . import openai_bp, openai_model_bp
@ -11,7 +13,6 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -64,24 +65,18 @@ def openai_chat_completions(model_name=None):
# Prevent issues on the backend. # Prevent issues on the backend.
return 'Invalid prompt', 400 return 'Invalid prompt', 400
event_id = None # Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt
start_time = time.time() start_time = time.time()
request_valid, invalid_response = handler.validate_request() request_valid, invalid_response = handler.validate_request()
if not request_valid: if not request_valid:
return invalid_response return invalid_response
else: else:
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
event = None event = None
if not handler.is_client_ratelimited(): if not handler.is_client_ratelimited():
start_time = time.time() event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
if not event: if not event:
log_to_db( log_to_db(
handler.client_ip, handler.client_ip,
@ -97,27 +92,6 @@ def openai_chat_completions(model_name=None):
) )
return handler.handle_ratelimited() return handler.handle_ratelimited()
# Once the worker receives our streaming request, it will tell us we are ready
# to begin inference.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if time.time() - start_time >= opts.backend_generate_request_timeout:
raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
# This shouldn't happen because the best model should be auto-selected.
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)
# Double check the model is still online
if not handler.check_online():
return return_invalid_model_err(handler.request_json_body['model'])
try: try:
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
@ -125,68 +99,62 @@ def openai_chat_completions(model_name=None):
oai_string = generate_oai_string(30) oai_string = generate_oai_string(30)
def generate(): def generate():
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
try: try:
response = generator(msg_to_backend, handler.backend_url) while True:
generated_text = '' stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
partial_response = b'' if not stream_data:
for chunk in response.iter_content(chunk_size=1): print("No message received in 30 seconds, closing stream.")
partial_response += chunk yield 'data: [DONE]\n\n'
if partial_response.endswith(b'\x00'): else:
json_strs = partial_response.split(b'\x00') for r_timestamp, item in stream_data[0][1]:
for json_str in json_strs: timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
if json_str: data = pickle.loads(item[b'data'])
try: if data['error']:
json_obj = json.loads(json_str.decode()) yield 'data: [DONE]\n\n'
new = json_obj['text'][0].split(handler.prompt + generated_text)[1] return
generated_text = generated_text + new elif data['new']:
redis.publish(event_id, 'chunk') # Keepalive response = {
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}", "id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": int(time.time()), "created": timestamp,
"model": model, "model": model,
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
"delta": { "delta": {
"content": new "content": data['new']
}, },
"finish_reason": None "finish_reason": None
} }
] ]
} }
yield f'data: {json.dumps(data)}\n\n' generated_text = generated_text + data['new']
yield 'data: [DONE]\n\n' yield f'data: {json.dumps(response)}\n\n'
end_time = time.time() elif data['completed']:
elapsed_time = end_time - start_time yield 'data: [DONE]\n\n'
log_to_db( end_time = time.time()
handler.client_ip, elapsed_time = end_time - start_time
handler.token, log_to_db(
handler.prompt, handler.client_ip,
generated_text, handler.token,
elapsed_time, handler.prompt,
handler.parameters, generated_text,
r_headers, elapsed_time,
200, handler.parameters,
r_url, r_headers,
handler.backend_url, 200,
) r_url,
except GeneratorExit: handler.backend_url,
yield 'data: [DONE]\n\n' )
except: return
# AttributeError: 'bool' object has no attribute 'iter_content' except (Exception, GeneratorExit):
traceback.print_exc() traceback.print_exc()
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
finally: finally:
# After completing inference, we need to tell the worker we are finished. stream_redis.delete(stream_name)
if event_id: # may be None if ratelimited.
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
return Response(generate(), mimetype='text/event-stream') return Response(generate(), mimetype='text/event-stream')
except Exception: except Exception:

View File

@ -27,7 +27,7 @@ class RedisPriorityQueue:
self.name = name self.name = name
self.redis = RedisCustom(name, db=db) self.redis = RedisCustom(name, db=db)
def put(self, item, priority, selected_model): def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
assert item is not None assert item is not None
assert priority is not None assert priority is not None
assert selected_model is not None assert selected_model is not None
@ -43,7 +43,7 @@ class RedisPriorityQueue:
return None # reject the request return None # reject the request
timestamp = time.time() timestamp = time.time()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority}) self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count') self.increment_ip_count(item[1], 'queued_ip_count')
return event return event
@ -106,6 +106,7 @@ class DataEvent:
self.redis.publish(self.event_id, pickle.dumps(data)) self.redis.publish(self.event_id, pickle.dumps(data))
def wait(self): def wait(self):
# TODO: implement timeout
for item in self.pubsub.listen(): for item in self.pubsub.listen():
if item['type'] == 'message': if item['type'] == 'message':
return pickle.loads(item['data']) return pickle.loads(item['data'])
@ -151,9 +152,9 @@ class PriorityQueue:
count += queue.get_queued_ip_count(client_ip) count += queue.get_queued_ip_count(client_ip)
return count return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str): def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url) queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model) return queue.put(item, priority, selected_model, do_stream)
def activity(self): def activity(self):
lines = [] lines = []

View File

@ -0,0 +1,32 @@
import time
from redis import Redis
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
# NOT NEEDED
def cleaner():
r = Redis(db=8)
stream_info = {}
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)
# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)

View File

@ -1,90 +1,88 @@
import queue import json
import pickle
import threading import threading
import time
import traceback import traceback
from uuid import uuid4 from uuid import uuid4
from redis.client import PubSub from redis import Redis
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis from llm_server.custom_redis import RedisCustom
from llm_server.llm.generator import generator from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
class ListenerThread(threading.Thread): STREAM_NAME_PREFIX = 'stream'
def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
threading.Thread.__init__(self)
self.pubsub = pubsub
self.listener_queue = listener_queue
self.stop_event = stop_event
def run(self):
while not self.stop_event.is_set(): def get_stream_name(name: str):
message = self.pubsub.get_message() return f'{STREAM_NAME_PREFIX}:{name}'
if message:
self.listener_queue.put(message)
time.sleep(0.1) def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
except Exception as e:
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
traceback.print_exc()
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
def worker(backend_url): def worker(backend_url):
status_redis = RedisCustom('worker_status') status_redis = RedisCustom('worker_status')
worker_id = uuid4() worker_id = str(uuid4())
status_redis.setp(str(worker_id), None) status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url) redis_queue = RedisPriorityQueue(backend_url)
while True: while True:
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get() (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
backend_info = cluster_config.get_backend(backend_url) backend_info = cluster_config.get_backend(backend_url)
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
stop_event = threading.Event()
q = queue.Queue()
listener = ListenerThread(pubsub, q, stop_event)
listener.start()
if not backend_info['online']: if not backend_info['online']:
redis.publish(event_id, 'offline') # TODO: communicate to caller
# redis.publish(event_id, 'offline')
return return
if not selected_model: if not selected_model:
selected_model = backend_info['model'] selected_model = backend_info['model']
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips') increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url) incr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), ('generating', client_ip))
status_redis.setp(str(worker_id), (backend_url, client_ip))
try: try:
if not request_json_body: if do_stream:
# This was a dummy request from the streaming handlers. event = DataEvent(event_id)
# The worker will let the handler do the streaming instead event.set(get_stream_name(worker_id))
# of the worker. The worker will block until the handler msg_to_backend = {
# is finished. Since a lot of ratelimiting and stats are **parameters,
# based off the number of active workers, we must keep 'prompt': request_json_body['prompt'],
# the generation based off the workers. 'stream': True,
start_time = time.time() }
redis.publish(event_id, 'begin') inference_do_stream(worker_id, msg_to_backend, backend_url)
while True:
status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
try:
item = q.get(timeout=30)
except queue.Empty:
print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
break
if time.time() - start_time >= opts.backend_generate_request_timeout:
status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
break
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
status_redis.setp(str(worker_id), ('streaming completed', client_ip))
break
else: else:
status_redis.setp(str(worker_id), ('generating', client_ip))
# Normal inference (not streaming). # Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url) success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id) event = DataEvent(event_id)
@ -92,8 +90,6 @@ def worker(backend_url):
except: except:
traceback.print_exc() traceback.print_exc()
finally: finally:
stop_event.set() # make sure to stop the listener thread
listener.join()
decrement_ip_count(client_ip, 'processing_ips') decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url) decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None) status_redis.setp(str(worker_id), None)

View File

@ -2,7 +2,6 @@ import time
from threading import Thread from threading import Thread
from llm_server import opts from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers from llm_server.workers.inferencer import start_workers