From e5fbc9545dbc80d3bb9e2b82b67175289d74a2b7 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Wed, 27 Sep 2023 21:15:54 -0600 Subject: [PATCH] add ratelimiting to websocket streaming endpoint, fix queue not decrementing IP requests, add console printer --- llm_server/llm/generator.py | 5 +- llm_server/routes/queue.py | 45 ++----- llm_server/routes/request_handler.py | 7 +- llm_server/routes/v1/generate_stream.py | 152 +++++++++++++----------- llm_server/workers/__init__.py | 0 llm_server/workers/blocking.py | 58 +++++++++ llm_server/workers/printer.py | 29 +++++ server.py | 11 +- 8 files changed, 191 insertions(+), 116 deletions(-) create mode 100644 llm_server/workers/__init__.py create mode 100644 llm_server/workers/blocking.py create mode 100644 llm_server/workers/printer.py diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index 3aca56b..5dd2093 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -3,8 +3,9 @@ from llm_server import opts def generator(request_json_body): if opts.mode == 'oobabooga': - from .oobabooga.generate import generate - return generate(request_json_body) + # from .oobabooga.generate import generate + # return generate(request_json_body) + raise NotImplementedError elif opts.mode == 'vllm': from .vllm.generate import generate r = generate(request_json_body) diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index ebe6641..d6c84f5 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -1,13 +1,11 @@ import json import pickle -import threading import time from uuid import uuid4 from redis import Redis from llm_server import opts -from llm_server.llm.generator import generator from llm_server.routes.cache import redis @@ -39,6 +37,8 @@ class RedisPriorityQueue: # Check if the IP is already in the dictionary and if it has reached the limit ip_count = self.redis.hget('queued_ip_count', item[1]) + if ip_count: + ip_count = int(ip_count) if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') return None # reject the request @@ -52,11 +52,10 @@ class RedisPriorityQueue: data = self.redis.zpopmin('queue') if data: item = json.loads(data[0][0]) - client_ip = item[1][1] - # Decrement the count for this IP + client_ip = item[0][1] self.decrement_ip_count(client_ip, 'queued_ip_count') return item - time.sleep(1) # wait for an item to be added to the queue + time.sleep(0.5) # wait for something to be added to the queue def increment_ip_count(self, ip, key): self.redis.hincrby(key, ip, 1) @@ -67,6 +66,13 @@ class RedisPriorityQueue: def __len__(self): return self.redis.zcard('queue') + def get_ip_count(self, client_ip: str): + x = self.redis.hget('queued_ip_count', client_ip) + if x: + return x.decode('utf-8') + else: + return x + class DataEvent: def __init__(self, event_id=None): @@ -85,32 +91,3 @@ class DataEvent: priority_queue = RedisPriorityQueue() - - -def worker(): - while True: - (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() - - increment_ip_count(client_ip, 'processing_ips') - redis.incr('active_gen_workers') - - try: - start_time = time.time() - success, response, error_msg = generator(request_json_body) - end_time = time.time() - - elapsed_time = end_time - start_time - redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time))) - - event = DataEvent(event_id) - event.set((success, response, error_msg)) - finally: - decrement_ip_count(client_ip, 'processing_ips') - redis.decr('active_gen_workers') - - -def start_workers(num_workers: int): - for _ in range(num_workers): - t = threading.Thread(target=worker) - t.daemon = True - t.start() diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index 222f479..a598816 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -12,7 +12,6 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.auth import parse_token from llm_server.routes.cache import redis -from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.queue import priority_queue @@ -134,7 +133,6 @@ class RequestHandler: request_valid, invalid_response = self.validate_request(prompt, do_log=True) if not request_valid: return (False, None, None, 0), invalid_response - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) else: event = None @@ -193,14 +191,11 @@ class RequestHandler: return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) def is_client_ratelimited(self) -> bool: - print('queued_ip_count', redis.get_dict('queued_ip_count')) - print('processing_ips', redis.get_dict('processing_ips')) - - queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0: return False else: + print(f'Rejecting request from {self.client_ip} - {queued_ip_count} queued + processing.') return True def handle_request(self) -> Tuple[flask.Response, int]: diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 288b1a2..2e39903 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -2,12 +2,14 @@ import json import threading import time import traceback +from typing import Union from flask import request -from ..helpers.client import format_sillytavern_err +from ..cache import redis from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler +from ..queue import decrement_ip_count, priority_queue from ... import opts from ...database.database import log_prompt from ...llm.generator import generator @@ -20,8 +22,31 @@ from ...stream import sock @sock.route('/api/v1/stream') def stream(ws): + def send_err_and_quit(quitting_err_msg): + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': 0, + 'text': quitting_err_msg + })) + ws.send(json.dumps({ + 'event': 'stream_end', + 'message_num': 1 + })) + ws.close() + log_in_bg(quitting_err_msg, is_error=True) + + def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None): + def background_task_exception(): + generated_tokens = tokenize(generated_text_bg) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error) + + # TODO: use async/await instead of threads + thread = threading.Thread(target=background_task_exception) + thread.start() + thread.join() + if not opts.enable_streaming: - return 'disabled', 401 + return 'Streaming is disabled', 401 r_headers = dict(request.headers) r_url = request.url @@ -30,12 +55,7 @@ def stream(ws): message = ws.receive() request_valid_json, request_json_body = validate_json(message) if not request_valid_json or not request_json_body.get('prompt'): - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': 'Invalid JSON' - })) - message_num += 1 + return 'Invalid JSON', 400 else: if opts.mode != 'vllm': # TODO: implement other backends @@ -50,36 +70,44 @@ def stream(ws): input_prompt = request_json_body['prompt'] response_status_code = 0 start_time = time.time() - request_valid, invalid_response = handler.validate_request(prompt=input_prompt) - if not request_valid: - err_msg = invalid_response[0].json['results'][0]['text'] - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': 0, - 'text': err_msg - })) - ws.send(json.dumps({ - 'event': 'stream_end', - 'message_num': 1 - })) - ws.close() # this is important if we encountered and error and exited early. - def background_task(): - log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task) - thread.start() - thread.join() + err_msg = None + if handler.is_client_ratelimited(): + r, _ = handler.handle_ratelimited() + err_msg = r.json['results'][0]['text'] else: - msg_to_backend = { - **handler.parameters, - 'prompt': input_prompt, - 'stream': True, - } - try: - response = generator(msg_to_backend) + request_valid, invalid_response = handler.validate_request(prompt=input_prompt) + if not request_valid: + err_msg = invalid_response[0].json['results'][0]['text'] + if err_msg: + send_err_and_quit(err_msg) + return + llm_request = { + **handler.parameters, + 'prompt': input_prompt, + 'stream': True, + } + + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority) + if not event: + r, _ = handler.handle_ratelimited() + err_msg = r.json['results'][0]['text'] + send_err_and_quit(err_msg) + return + try: + response = generator(llm_request) + if not response: + error_msg = 'Failed to reach backend while streaming.' + print('Streaming failed:', error_msg) + msg = handler.handle_error(error_msg)[0].json['results'][0]['text'] + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': msg + })) + else: # Be extra careful when getting attributes from the response object try: response_status_code = response.status_code @@ -88,14 +116,6 @@ def stream(ws): partial_response = b'' -# TODO: handle when the backend is offline -# Traceback (most recent call last): -# File "/srv/server/local-llm-server/llm_server/routes/v1/generate_stream.py", line 91, in stream -# for chunk in response.iter_content(chunk_size=1): -# ^^^^^^^^^^^^^^^^^^^^^ -# AttributeError: 'NoneType' object has no attribute 'iter_content' - - for chunk in response.iter_content(chunk_size=1): partial_response += chunk if partial_response.endswith(b'\x00'): @@ -116,6 +136,7 @@ def stream(ws): 'text': new })) except: + # The client closed the stream. end_time = time.time() elapsed_time = end_time - start_time log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) @@ -128,41 +149,32 @@ def stream(ws): if not chunk: break + if response: response.close() - end_time = time.time() - elapsed_time = end_time - start_time - - def background_task_success(): - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task_success) - thread.start() - thread.join() - except: - generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8') - traceback.print_exc() - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': generated_text - })) - - def background_task_exception(): - generated_tokens = tokenize(generated_text) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task_exception) - thread.start() - thread.join() + end_time = time.time() + elapsed_time = end_time - start_time + log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code) + except: + traceback.print_exc() + generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': generated_text + })) + log_in_bg(generated_text, is_error=True, status_code=response_status_code) + finally: + # The worker incremented it, we'll decrement it. + decrement_ip_count(handler.client_ip, 'processing_ips') + redis.decr('active_gen_workers') try: ws.send(json.dumps({ 'event': 'stream_end', 'message_num': message_num })) except: + # The client closed the stream. end_time = time.time() elapsed_time = end_time - start_time log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) diff --git a/llm_server/workers/__init__.py b/llm_server/workers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py new file mode 100644 index 0000000..91a04c4 --- /dev/null +++ b/llm_server/workers/blocking.py @@ -0,0 +1,58 @@ +import json +import threading +import time + +from llm_server import opts +from llm_server.llm.generator import generator +from llm_server.routes.cache import redis +from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue + + +def worker(): + while True: + need_to_wait() + (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() + need_to_wait() + + increment_ip_count(client_ip, 'processing_ips') + redis.incr('active_gen_workers') + + if not request_json_body: + # This was a dummy request from the websocket handler. + # We're going to let the websocket handler decrement processing_ips and active_gen_workers. + continue + + try: + start_time = time.time() + success, response, error_msg = generator(request_json_body) + end_time = time.time() + + elapsed_time = end_time - start_time + # redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time))) + + event = DataEvent(event_id) + event.set((success, response, error_msg)) + finally: + decrement_ip_count(client_ip, 'processing_ips') + redis.decr('active_gen_workers') + + +def start_workers(num_workers: int): + i = 0 + for _ in range(num_workers): + t = threading.Thread(target=worker) + t.daemon = True + t.start() + i += 1 + print(f'Started {i} inference workers.') + + +def need_to_wait(): + # We need to check the number of active workers since the streaming endpoint may be doing something. + active_workers = redis.get('active_gen_workers', int, 0) + s = time.time() + while active_workers >= opts.concurrent_gens: + time.sleep(0.01) + e = time.time() + if e - s > 0.5: + print(f'Worker was delayed {e - s} seconds.') diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py new file mode 100644 index 0000000..aa45ccf --- /dev/null +++ b/llm_server/workers/printer.py @@ -0,0 +1,29 @@ +import logging +import threading +import time + +from llm_server.routes.cache import redis + + +def console_printer(): + logger = logging.getLogger('console_printer') + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + while True: + queued_ip_count = redis.get_dict('queued_ip_count') + queued_ip_count = sum([v for k, v in queued_ip_count.items()]) + processing_ips = redis.get_dict('processing_ips') + processing_count = sum([v for k, v in processing_ips.items()]) + + logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}') + time.sleep(10) + + +def start_console_printer(): + t = threading.Thread(target=console_printer) + t.daemon = True + t.start() diff --git a/server.py b/server.py index 943c745..1ee61b2 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,4 @@ -from redis import Redis +from llm_server.workers.printer import start_console_printer try: import gevent.monkey @@ -16,6 +16,7 @@ from threading import Thread import openai import simplejson as json from flask import Flask, jsonify, render_template, request +from redis import Redis import llm_server from llm_server.database.conn import database @@ -26,6 +27,7 @@ from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.stream import init_socketio +from llm_server.workers.blocking import start_workers # TODO: have the workers handle streaming too # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail @@ -36,6 +38,7 @@ from llm_server.stream import init_socketio # TODO: implement RRD backend loadbalancer option # TODO: have VLLM reject a request if it already has n == concurrent_gens running # TODO: add a way to cancel VLLM gens. Maybe use websockets? +# TODO: use coloredlogs # Lower priority # TODO: the processing stat showed -1 and I had to restart the server @@ -65,7 +68,6 @@ from llm_server.helpers import resolve_path, auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info from llm_server.routes.cache import RedisWrapper, flask_cache from llm_server.llm import redis -from llm_server.routes.queue import start_workers from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers from llm_server.routes.v1.generate_stats import generate_stats from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers @@ -166,6 +168,8 @@ def pre_fork(server): redis.set_dict('processing_ips', {}) redis.set_dict('queued_ip_count', {}) + + # Flush the RedisPriorityQueue database. queue_redis = Redis(host='localhost', port=6379, db=15) for key in queue_redis.scan_iter('*'): queue_redis.delete(key) @@ -181,8 +185,7 @@ def pre_fork(server): # Start background processes start_workers(opts.concurrent_gens) - print(f'Started {opts.concurrent_gens} inference workers.') - + start_console_printer() start_moderation_workers(opts.openai_moderation_workers) MainBackgroundThread().start() SemaphoreCheckerThread().start()