diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index b1e74be..a840070 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -1,8 +1,10 @@
 import json
+import pickle
 import time
 import traceback
 
 from flask import Response, jsonify, request
+from redis import Redis
 
 from llm_server.custom_redis import redis
 from . import openai_bp, openai_model_bp
@@ -11,7 +13,6 @@
 from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import priority_queue
 from ... import opts
 from ...database.log_to_db import log_to_db
-from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -64,24 +65,18 @@ def openai_chat_completions(model_name=None):
         # Prevent issues on the backend.
         return 'Invalid prompt', 400
 
-    event_id = None
+    # Need to set the prompt in the JSON body since that's what the inference worker expects.
+    handler.request_json_body['prompt'] = handler.prompt
+
+    start_time = time.time()
 
     request_valid, invalid_response = handler.validate_request()
    if not request_valid:
         return invalid_response
     else:
-        msg_to_backend = {
-            **handler.parameters,
-            'prompt': handler.prompt,
-            'stream': True,
-        }
-
         event = None
         if not handler.is_client_ratelimited():
-            start_time = time.time()
-            # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
         if not event:
             log_to_db(
                 handler.client_ip,
@@ -97,27 +92,6 @@ def openai_chat_completions(model_name=None):
             )
             return handler.handle_ratelimited()
 
-        # Once the worker receives our streaming request, it will tell us we are ready
-        # to begin inference.
-        event_id = event.event_id
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        for item in pubsub.listen():
-            if time.time() - start_time >= opts.backend_generate_request_timeout:
-                raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
-            if item['type'] == 'message':
-                msg = item['data'].decode('utf-8')
-                if msg == 'begin':
-                    break
-                elif msg == 'offline':
-                    # This shouldn't happen because the best model should be auto-selected.
-                    return return_invalid_model_err(handler.request_json_body['model'])
-            time.sleep(0.1)
-
-        # Double check the model is still online
-        if not handler.check_online():
-            return return_invalid_model_err(handler.request_json_body['model'])
-
         try:
             r_headers = dict(request.headers)
             r_url = request.url
@@ -125,68 +99,62 @@ def openai_chat_completions(model_name=None):
             oai_string = generate_oai_string(30)
 
             def generate():
+                stream_name = event.wait()
+                stream_redis = Redis(db=8)
+                generated_text = ''
                 try:
-                    response = generator(msg_to_backend, handler.backend_url)
-                    generated_text = ''
-                    partial_response = b''
-                    for chunk in response.iter_content(chunk_size=1):
-                        partial_response += chunk
-                        if partial_response.endswith(b'\x00'):
-                            json_strs = partial_response.split(b'\x00')
-                            for json_str in json_strs:
-                                if json_str:
-                                    try:
-                                        json_obj = json.loads(json_str.decode())
-                                        new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                        generated_text = generated_text + new
-                                        redis.publish(event_id, 'chunk')  # Keepalive
-                                    except IndexError:
-                                        # ????
-                                        continue
-
-                                    data = {
+                    while True:
+                        stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
+                        if not stream_data:
+                            print("No message received in 30 seconds, closing stream.")
+                            yield 'data: [DONE]\n\n'
+                        else:
+                            for r_timestamp, item in stream_data[0][1]:
+                                timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
+                                data = pickle.loads(item[b'data'])
+                                if data['error']:
+                                    yield 'data: [DONE]\n\n'
+                                    return
+                                elif data['new']:
+                                    response = {
                                         "id": f"chatcmpl-{oai_string}",
                                         "object": "chat.completion.chunk",
-                                        "created": int(time.time()),
+                                        "created": timestamp,
                                         "model": model,
                                         "choices": [
                                             {
                                                 "index": 0,
                                                 "delta": {
-                                                    "content": new
+                                                    "content": data['new']
                                                 },
                                                 "finish_reason": None
                                             }
                                         ]
                                     }
-                                    yield f'data: {json.dumps(data)}\n\n'
-                    yield 'data: [DONE]\n\n'
-                    end_time = time.time()
-                    elapsed_time = end_time - start_time
-                    log_to_db(
-                        handler.client_ip,
-                        handler.token,
-                        handler.prompt,
-                        generated_text,
-                        elapsed_time,
-                        handler.parameters,
-                        r_headers,
-                        200,
-                        r_url,
-                        handler.backend_url,
-                    )
-                except GeneratorExit:
-                    yield 'data: [DONE]\n\n'
-                except:
-                    # AttributeError: 'bool' object has no attribute 'iter_content'
+                                    generated_text = generated_text + data['new']
+                                    yield f'data: {json.dumps(response)}\n\n'
+                                elif data['completed']:
+                                    yield 'data: [DONE]\n\n'
+                                    end_time = time.time()
+                                    elapsed_time = end_time - start_time
+                                    log_to_db(
+                                        handler.client_ip,
+                                        handler.token,
+                                        handler.prompt,
+                                        generated_text,
+                                        elapsed_time,
+                                        handler.parameters,
+                                        r_headers,
+                                        200,
+                                        r_url,
+                                        handler.backend_url,
+                                    )
+                                    return
+                except (Exception, GeneratorExit):
                     traceback.print_exc()
                     yield 'data: [DONE]\n\n'
                 finally:
-                    # After completing inference, we need to tell the worker we are finished.
-                    if event_id:  # may be None if ratelimited.
-                        redis.publish(event_id, 'finished')
-                    else:
-                        print('event_id was None!')
+                    stream_redis.delete(stream_name)
 
             return Response(generate(), mimetype='text/event-stream')
         except Exception:
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index dcbdfc2..cb4aaf5 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -27,7 +27,7 @@ class RedisPriorityQueue:
         self.name = name
         self.redis = RedisCustom(name, db=db)
 
-    def put(self, item, priority, selected_model):
+    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
         assert item is not None
         assert priority is not None
         assert selected_model is not None
@@ -43,7 +43,7 @@ class RedisPriorityQueue:
             return None  # reject the request
 
         timestamp = time.time()
-        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
+        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
         self.increment_ip_count(item[1], 'queued_ip_count')
         return event
 
@@ -106,6 +106,7 @@ class DataEvent:
         self.redis.publish(self.event_id, pickle.dumps(data))
 
     def wait(self):
+        # TODO: implement timeout
         for item in self.pubsub.listen():
             if item['type'] == 'message':
                 return pickle.loads(item['data'])
@@ -151,9 +152,9 @@ class PriorityQueue:
             count += queue.get_queued_ip_count(client_ip)
         return count
 
-    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
+    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
         queue = RedisPriorityQueue(backend_url)
-        return queue.put(item, priority, selected_model)
+        return queue.put(item, priority, selected_model, do_stream)
 
     def activity(self):
         lines = []
diff --git a/llm_server/workers/cleaner.py b/llm_server/workers/cleaner.py
new file mode 100644
index 0000000..95a6a78
--- /dev/null
+++ b/llm_server/workers/cleaner.py
@@ -0,0 +1,32 @@
+import time
+
+from redis import Redis
+
+from llm_server.workers.inferencer import STREAM_NAME_PREFIX
+
+
+# NOT NEEDED
+
+def cleaner():
+    r = Redis(db=8)
+    stream_info = {}
+
+    while True:
+        all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
+        processed_streams = []
+        for stream in all_streams:
+            stream = stream.decode()
+            current_size = r.xlen(stream)
+
+            # If the stream is new or its size has changed, update the size and time in the dictionary
+            if stream not in stream_info or current_size != stream_info[stream]['size']:
+                stream_info[stream] = {'size': current_size, 'time': time.time()}
+                processed_streams.append(stream)
+            else:
+                # If the size hasn't changed for 5 minutes, delete the stream
+                if time.time() - stream_info[stream]['time'] >= 300:
+                    r.delete(stream)
+                    print(f"Stream '{stream}' deleted due to inactivity.")
+                    del stream_info[stream]
+
+        time.sleep(60)
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index f41d0d3..a8a73e0 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -1,90 +1,88 @@
-import queue
+import json
+import pickle
 import threading
-import time
 import traceback
 from uuid import uuid4
 
-from redis.client import PubSub
+from redis import Redis
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import RedisCustom, redis
+from llm_server.custom_redis import RedisCustom
 from llm_server.llm.generator import generator
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
 
+stream_redis = Redis(db=8)
 
-class ListenerThread(threading.Thread):
-    def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
-        threading.Thread.__init__(self)
-        self.pubsub = pubsub
-        self.listener_queue = listener_queue
-        self.stop_event = stop_event
+STREAM_NAME_PREFIX = 'stream'
 
-    def run(self):
-        while not self.stop_event.is_set():
-            message = self.pubsub.get_message()
-            if message:
-                self.listener_queue.put(message)
-            time.sleep(0.1)
+
+def get_stream_name(name: str):
+    return f'{STREAM_NAME_PREFIX}:{name}'
+
+
+def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
+    prompt = msg_to_backend['prompt']
+    stream_name = get_stream_name(stream_name)
+    stream_redis.delete(get_stream_name(stream_name))  # be extra sure
+    try:
+        response = generator(msg_to_backend, backend_url)
+        generated_text = ''
+        partial_response = b''
+        for chunk in response.iter_content(chunk_size=1):
+            partial_response += chunk
+            if partial_response.endswith(b'\x00'):
+                json_strs = partial_response.split(b'\x00')
+                for json_str in json_strs:
+                    if json_str:
+                        try:
+                            json_obj = json.loads(json_str.decode())
+                            new = json_obj['text'][0].split(prompt + generated_text)[1]
+                            generated_text = generated_text + new
+                        except IndexError:
+                            # ????
+                            continue
+                        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
+    except Exception as e:
+        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
+        traceback.print_exc()
+    finally:
+        # Publish final message to Redis stream
+        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
 
 
 def worker(backend_url):
     status_redis = RedisCustom('worker_status')
-    worker_id = uuid4()
+    worker_id = str(uuid4())
     status_redis.setp(str(worker_id), None)
     redis_queue = RedisPriorityQueue(backend_url)
     while True:
-        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
+        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
         backend_info = cluster_config.get_backend(backend_url)
 
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        stop_event = threading.Event()
-        q = queue.Queue()
-        listener = ListenerThread(pubsub, q, stop_event)
-        listener.start()
-
         if not backend_info['online']:
-            redis.publish(event_id, 'offline')
+            # TODO: communicate to caller
+            # redis.publish(event_id, 'offline')
             return
 
         if not selected_model:
             selected_model = backend_info['model']
 
+        stream_redis.delete(get_stream_name(worker_id))  # clean up any old streams
+
         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
-
-        status_redis.setp(str(worker_id), (backend_url, client_ip))
+        status_redis.setp(str(worker_id), ('generating', client_ip))
 
         try:
-            if not request_json_body:
-                # This was a dummy request from the streaming handlers.
-                # The worker will let the handler do the streaming instead
-                # of the worker. The worker will block until the handler
-                # is finished. Since a lot of ratelimiting and stats are
-                # based off the number of active workers, we must keep
-                # the generation based off the workers.
-                start_time = time.time()
-                redis.publish(event_id, 'begin')
-                while True:
-                    status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
-
-                    try:
-                        item = q.get(timeout=30)
-                    except queue.Empty:
-                        print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
-                        status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
-                        break
-
-                    if time.time() - start_time >= opts.backend_generate_request_timeout:
-                        status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
-                        print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
-                        break
-                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
-                        status_redis.setp(str(worker_id), ('streaming completed', client_ip))
-                        break
+            if do_stream:
+                event = DataEvent(event_id)
+                event.set(get_stream_name(worker_id))
+                msg_to_backend = {
+                    **parameters,
+                    'prompt': request_json_body['prompt'],
+                    'stream': True,
+                }
+                inference_do_stream(worker_id, msg_to_backend, backend_url)
             else:
-                status_redis.setp(str(worker_id), ('generating', client_ip))
                 # Normal inference (not streaming).
                 success, response, error_msg = generator(request_json_body, backend_url)
                 event = DataEvent(event_id)
@@ -92,8 +90,6 @@ def worker(backend_url):
         except:
             traceback.print_exc()
         finally:
-            stop_event.set()  # make sure to stop the listener thread
-            listener.join()
             decrement_ip_count(client_ip, 'processing_ips')
             decr_active_workers(selected_model, backend_url)
             status_redis.setp(str(worker_id), None)
diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py
index f19ce1c..0e47c02 100644
--- a/llm_server/workers/threader.py
+++ b/llm_server/workers/threader.py
@@ -2,7 +2,6 @@ import time
 from threading import Thread
 
 from llm_server import opts
-from llm_server.cluster.stores import redis_running_models
 from llm_server.cluster.worker import cluster_worker
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers
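
Note: the handshake this patch introduces reduces to the worker XADD-ing pickled chunk dicts ({'new', 'completed', 'error'}) onto a per-request Redis stream, while the Flask generator polls that stream with XREAD and deletes it when finished. Below is a minimal, standalone sketch of that pattern for reference. It is illustrative only, not part of the patch: it assumes a local Redis instance, uses a made-up stream name ('stream:demo'), and tracks the last-delivered entry ID (the standard incremental XREAD form) rather than re-reading from '0-0' as the handler above does.

# Minimal sketch of the XADD/XREAD producer/consumer pattern used by the patch (illustrative only).
# Assumes a local Redis on the default port; 'stream:demo' is a made-up stream name.
import pickle

from redis import Redis

stream_redis = Redis(db=8)
STREAM_NAME = 'stream:demo'


def produce(chunks):
    # Worker side: push each generated chunk, then a final 'completed' marker.
    for chunk in chunks:
        stream_redis.xadd(STREAM_NAME, {'data': pickle.dumps({'new': chunk, 'completed': False, 'error': None})})
    stream_redis.xadd(STREAM_NAME, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})


def consume():
    # Handler side: read entries newer than the last seen ID and stop on the 'completed' marker.
    last_id = '0-0'
    try:
        while True:
            stream_data = stream_redis.xread({STREAM_NAME: last_id}, block=30000)
            if not stream_data:
                break  # nothing arrived within 30 seconds
            for entry_id, item in stream_data[0][1]:
                last_id = entry_id
                data = pickle.loads(item[b'data'])
                if data['error'] or data['completed']:
                    return
                print(data['new'], end='', flush=True)
    finally:
        stream_redis.delete(STREAM_NAME)


if __name__ == '__main__':
    produce(['Hello', ', ', 'world\n'])
    consume()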