From e8964fcfd29675c2d3dde5b70395f6f48cbe9309 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Thu, 5 Oct 2023 21:37:18 -0600
Subject: [PATCH] fix the queue??

---
 daemon.py                                    |  13 +--
 llm_server/config/load.py                    |   4 +
 llm_server/routes/openai/chat_completions.py | 111 ++++++++++---------
 llm_server/routes/openai/completions.py      | 105 +++++++++---------
 llm_server/routes/queue.py                   |  64 ++++++++++-
 llm_server/routes/request_handler.py         |   4 +-
 llm_server/routes/v1/generate_stream.py      |   2 +-
 llm_server/workers/inferencer.py             |  45 +++-----
 llm_server/workers/printer.py                |   2 -
 llm_server/workers/threader.py               |   2 +-
 10 files changed, 195 insertions(+), 157 deletions(-)

diff --git a/daemon.py b/daemon.py
index 0fa3601..35c1d59 100644
--- a/daemon.py
+++ b/daemon.py
@@ -3,13 +3,12 @@ import sys
 import time
 from pathlib import Path
 
+from redis import Redis
+
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.cluster.redis_cycle import redis_cycler_db
-from llm_server.cluster.stores import redis_running_models
 from llm_server.config.load import load_config, parse_backends
 from llm_server.custom_redis import redis
 from llm_server.database.create import create_db
-from llm_server.routes.queue import priority_queue
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
 
@@ -21,11 +20,8 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')
 
 if __name__ == "__main__":
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-
-    redis_cycler_db.flushall()
-    redis_running_models.flush()
+    Redis().flushall()
+    print('Flushed Redis.')
 
     success, config, msg = load_config(config_path)
     if not success:
@@ -34,7 +30,6 @@ if __name__ == "__main__":
 
     create_db()
 
-    priority_queue.flush()
     cluster_config.clear()
     cluster_config.load(parse_backends(config))
 
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 2847265..cc3250c 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -3,11 +3,13 @@ import sys
 
 import openai
 
+import llm_server
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
 from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import get_number_of_rows
+from llm_server.routes.queue import PriorityQueue
 
 
 def load_config(config_path):
@@ -54,6 +56,8 @@ def load_config(config_path):
     for item in config['cluster']:
         opts.cluster_workers += item['concurrent_gens']
 
+    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
+
     if opts.openai_expose_our_model and not opts.openai_api_key:
         print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
         sys.exit(1)
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 6e1fdf5..bcbd24c 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -74,7 +74,7 @@ def openai_chat_completions():
         event = None
         if not handler.is_client_ratelimited():
             # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
         if not event:
             log_to_db(
                 handler.client_ip,
@@ -107,63 +107,64 @@ def openai_chat_completions():
         oai_string = generate_oai_string(30)
 
         def generate():
-            response = generator(msg_to_backend, handler.backend_url)
-            generated_text = ''
-            partial_response = b''
-            for chunk in response.iter_content(chunk_size=1):
-                partial_response += chunk
-                if partial_response.endswith(b'\x00'):
-                    json_strs = partial_response.split(b'\x00')
-                    for json_str in json_strs:
-                        if json_str:
-                            try:
-                                json_obj = json.loads(json_str.decode())
-                                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                generated_text = generated_text + new
-                            except IndexError:
-                                # ????
-                                continue
+            try:
+                response = generator(msg_to_backend, handler.backend_url)
+                generated_text = ''
+                partial_response = b''
+                for chunk in response.iter_content(chunk_size=1):
+                    partial_response += chunk
+                    if partial_response.endswith(b'\x00'):
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    json_obj = json.loads(json_str.decode())
+                                    new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                    generated_text = generated_text + new
+                                except IndexError:
+                                    # ????
+                                    continue
 
-                            data = {
-                                "id": f"chatcmpl-{oai_string}",
-                                "object": "chat.completion.chunk",
-                                "created": int(time.time()),
-                                "model": model,
-                                "choices": [
-                                    {
-                                        "index": 0,
-                                        "delta": {
-                                            "content": new
-                                        },
-                                        "finish_reason": None
-                                    }
-                                ]
-                            }
-                            yield f'data: {json.dumps(data)}\n\n'
-            yield 'data: [DONE]\n\n'
-            end_time = time.time()
-            elapsed_time = end_time - start_time
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.prompt,
-                generated_text,
-                elapsed_time,
-                handler.parameters,
-                r_headers,
-                response_status_code,
-                r_url,
-                handler.backend_url,
-            )
+                                data = {
+                                    "id": f"chatcmpl-{oai_string}",
+                                    "object": "chat.completion.chunk",
+                                    "created": int(time.time()),
+                                    "model": model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {
+                                                "content": new
+                                            },
+                                            "finish_reason": None
+                                        }
+                                    ]
+                                }
+                                yield f'data: {json.dumps(data)}\n\n'
+                yield 'data: [DONE]\n\n'
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                log_to_db(
+                    handler.client_ip,
+                    handler.token,
+                    handler.prompt,
+                    generated_text,
+                    elapsed_time,
+                    handler.parameters,
+                    r_headers,
+                    response_status_code,
+                    r_url,
+                    handler.backend_url,
+                )
+            finally:
+                # After completing inference, we need to tell the worker we
+                # are finished.
+                if event_id:  # may be None if ratelimited.
+                    redis.publish(event_id, 'finished')
+                else:
+                    print('event_id was None!')
 
         return Response(generate(), mimetype='text/event-stream')
     except Exception:
         traceback.print_exc()
         return 'INTERNAL SERVER', 500
-    finally:
-        # After completing inference, we need to tell the worker we
-        # are finished.
-        if event_id:  # may be None if ratelimited.
-            redis.publish(event_id, 'finished')
-        else:
-            print('event_id was None!')
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 9c42cf6..8b5d987 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -102,7 +102,7 @@ def openai_completions():
         event = None
         if not handler.is_client_ratelimited():
             # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
         if not event:
             log_to_db(
                 handler.client_ip,
@@ -135,61 +135,62 @@ def openai_completions():
         oai_string = generate_oai_string(30)
 
         def generate():
-            generated_text = ''
-            partial_response = b''
-            for chunk in response.iter_content(chunk_size=1):
-                partial_response += chunk
-                if partial_response.endswith(b'\x00'):
-                    json_strs = partial_response.split(b'\x00')
-                    for json_str in json_strs:
-                        if json_str:
-                            try:
-                                json_obj = json.loads(json_str.decode())
-                                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                generated_text = generated_text + new
-                            except IndexError:
-                                # ????
-                                continue
+            try:
+                generated_text = ''
+                partial_response = b''
+                for chunk in response.iter_content(chunk_size=1):
+                    partial_response += chunk
+                    if partial_response.endswith(b'\x00'):
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    json_obj = json.loads(json_str.decode())
+                                    new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                    generated_text = generated_text + new
+                                except IndexError:
+                                    # ????
+                                    continue
 
-                            data = {
-                                "id": f"cmpl-{oai_string}",
-                                "object": "text_completion",
-                                "created": int(time.time()),
-                                "model": model,
-                                "choices": [
-                                    {
-                                        "index": 0,
-                                        "delta": {
-                                            "content": new
-                                        },
-                                        "finish_reason": None
-                                    }
-                                ]
-                            }
-                            yield f'data: {json.dumps(data)}\n\n'
-            yield 'data: [DONE]\n\n'
-            end_time = time.time()
-            elapsed_time = end_time - start_time
+                                data = {
+                                    "id": f"cmpl-{oai_string}",
+                                    "object": "text_completion",
+                                    "created": int(time.time()),
+                                    "model": model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {
+                                                "content": new
+                                            },
+                                            "finish_reason": None
+                                        }
+                                    ]
+                                }
+                                yield f'data: {json.dumps(data)}\n\n'
+                yield 'data: [DONE]\n\n'
+                end_time = time.time()
+                elapsed_time = end_time - start_time
 
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.prompt,
-                generated_text,
-                elapsed_time,
-                handler.parameters,
-                r_headers,
-                response_status_code,
-                r_url,
-                handler.backend_url,
-            )
+                log_to_db(
+                    handler.client_ip,
+                    handler.token,
+                    handler.prompt,
+                    generated_text,
+                    elapsed_time,
+                    handler.parameters,
+                    r_headers,
+                    response_status_code,
+                    r_url,
+                    handler.backend_url,
+                )
+            finally:
+                if event_id:
+                    redis.publish(event_id, 'finished')
+                else:
+                    print('event_id was None!')
 
         return Response(generate(), mimetype='text/event-stream')
     except Exception:
         traceback.print_exc()
         return 'INTERNAL SERVER', 500
-    finally:
-        if event_id:
-            redis.publish(event_id, 'finished')
-        else:
-            print('event_id was None!')
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index b075ead..d88ed45 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -1,10 +1,12 @@
 import json
 import pickle
 import time
+from typing import Tuple
 from uuid import uuid4
 
 from redis import Redis
 
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.database.database import get_token_ratelimit
 
@@ -20,7 +22,7 @@ def decrement_ip_count(client_ip: str, redis_key):
 
 
 class RedisPriorityQueue:
-    def __init__(self, name: str = 'priority_queue', db: int = 12):
+    def __init__(self, name, db: int = 12):
         self.redis = RedisCustom(name, db=db)
 
     def put(self, item, priority, selected_model):
@@ -98,9 +100,6 @@ class DataEvent:
         return pickle.loads(item['data'])
 
 
-priority_queue = RedisPriorityQueue()
-
-
 def update_active_workers(key: str, operation: str):
     if operation == 'incr':
         redis.incr(f'active_gen_workers:{key}')
@@ -118,3 +117,60 @@ def incr_active_workers(selected_model: str, backend_url: str):
 def decr_active_workers(selected_model: str, backend_url: str):
     update_active_workers(selected_model, 'decr')
     update_active_workers(backend_url, 'decr')
+
+
+class PriorityQueue:
+    def __init__(self, backends: list = None):
+        """
+        Only have to load the backends once.
+        :param backends:
+        """
+        self.redis = Redis(host='localhost', port=6379, db=9)
+        if backends:
+            for item in backends:
+                self.redis.lpush('backends', item)
+
+    def get_backends(self):
+        return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]
+
+    def get_queued_ip_count(self, client_ip: str):
+        count = 0
+        for backend_url in self.get_backends():
+            queue = RedisPriorityQueue(backend_url)
+            count += queue.get_queued_ip_count(client_ip)
+        return count
+
+    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
+        queue = RedisPriorityQueue(backend_url)
+        return queue.put(item, priority, selected_model)
+
+    def len(self, model_name):
+        count = 0
+        backends_with_models = []
+        for k in self.get_backends():
+            info = cluster_config.get_backend(k)
+            if info.get('model') == model_name:
+                backends_with_models.append(k)
+        for backend_url in backends_with_models:
+            queue = RedisPriorityQueue(backend_url)
+            count += queue.len(model_name)
+        return count
+
+    def __len__(self):
+        count = 0
+        for backend_url in self.get_backends():
+            queue = RedisPriorityQueue(backend_url)
+            count += len(queue)
+        return count
+
+    def flush(self):
+        for k in self.redis.keys():
+            q = json.loads(self.redis.get(k))
+            q.flush()
+            self.redis.set(k, json.dumps(q))
+
+    def flush_db(self):
+        self.redis.flushdb()
+
+
+priority_queue = PriorityQueue()
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index ef5aa34..a048df7 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -7,7 +7,7 @@ from flask import Response, request
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
 from llm_server.custom_redis import redis
-from llm_server.database.database import get_token_ratelimit, do_db_log
+from llm_server.database.database import get_token_ratelimit
 from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
@@ -131,7 +131,7 @@ class RequestHandler:
             request_valid, invalid_response = self.validate_request(prompt, do_log=True)
             if not request_valid:
                 return (False, None, None, 0), invalid_response
-            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model)
+            event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
         else:
             event = None
 
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index e3818c2..c55e36f 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -122,7 +122,7 @@ def do_stream(ws, model_name):
         event = None
         if not handler.is_client_ratelimited():
             # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
         if not event:
             log_to_db(
                 handler.client_ip,
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index 4a1e61f..324c13a 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -1,29 +1,24 @@
 import threading
 import time
+from uuid import uuid4
 
-from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
-from llm_server.custom_redis import redis
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis, RedisCustom
 from llm_server.llm.generator import generator
-from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
+from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, RedisPriorityQueue, PriorityQueue, priority_queue
 
 
-def worker():
+def worker(backend_url):
+    queue = RedisPriorityQueue(backend_url)
     while True:
-        (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
-        if not backend_url:
-            backend_url = get_a_cluster_backend(selected_model)
-        else:
-            backend_url = cluster_config.validate_backend(backend_url)
+        (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
         backend_info = cluster_config.get_backend(backend_url)
-
         if not selected_model:
             selected_model = backend_info['model']
 
         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
 
-        need_to_wait(backend_url)
-
         try:
             if not request_json_body:
                 # This was a dummy request from the streaming handlers.
@@ -37,7 +32,6 @@
                 redis.publish(event_id, 'begin')
                 for item in pubsub.listen():
                     if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
-                        # Once the handler is complete, move on.
                         break
                     time.sleep(0.1)
             else:
@@ -50,23 +44,12 @@
         decr_active_workers(selected_model, backend_url)
 
 
-def start_workers(num_workers: int):
+def start_workers(cluster: dict):
     i = 0
-    for _ in range(num_workers):
-        t = threading.Thread(target=worker)
-        t.daemon = True
-        t.start()
-        i += 1
+    for item in cluster:
+        for _ in range(item['concurrent_gens']):
+            t = threading.Thread(target=worker, args=(item['backend_url'],))
+            t.daemon = True
+            t.start()
+            i += 1
     print(f'Started {i} inference workers.')
-
-
-def need_to_wait(backend_url: str):
-    # We need to check the number of active workers since the streaming endpoint may be doing something.
-    active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int)
-    concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
-    s = time.time()
-    while active_workers >= concurrent_gens:
-        time.sleep(0.01)
-    e = time.time()
-    if e - s > 0.1:
-        print(f'Worker was delayed {e - s} seconds.')
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index 7be02d7..cf691c1 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -25,6 +25,4 @@ def console_printer():
             processing_count += redis.get(k, default=0, dtype=int)
         backends = [k for k, v in cluster_config.all().items() if v['online']]
         logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        priority_queue.print_all_items()
-        print('============================')
         time.sleep(1)
diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py
index dbdc8e0..1f5266f 100644
--- a/llm_server/workers/threader.py
+++ b/llm_server/workers/threader.py
@@ -20,7 +20,7 @@ def cache_stats():
 
 
 def start_background():
-    start_workers(opts.cluster_workers)
+    start_workers(opts.cluster)
 
     t = Thread(target=main_background_thread)
     t.daemon = True