From f7e9687527333a6e7e1fb479c8d30d2b8372e6f7 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Sun, 1 Oct 2023 16:04:53 -0600 Subject: [PATCH] finish openai endpoints --- README.md | 4 +- llm_server/cluster/backend.py | 7 +- llm_server/llm/generator.py | 6 +- llm_server/llm/llm_backend.py | 6 +- llm_server/llm/openai/transform.py | 34 +-- llm_server/llm/vllm/tokenize.py | 44 ++-- llm_server/routes/ooba_request_handler.py | 11 +- llm_server/routes/openai/chat_completions.py | 153 +++++++----- llm_server/routes/openai/completions.py | 249 ++++++++++--------- llm_server/routes/openai_request_handler.py | 24 +- llm_server/routes/request_handler.py | 2 +- llm_server/routes/v1/generate_stream.py | 4 + llm_server/routes/v1/info.py | 1 + llm_server/workers/inferencer.py | 4 +- requirements.txt | 4 +- 15 files changed, 311 insertions(+), 242 deletions(-) diff --git a/README.md b/README.md index c95e083..ccfaaf4 100644 --- a/README.md +++ b/README.md @@ -43,12 +43,10 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database. ### Use -If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. +If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database. Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app` - - ### To Do - [x] Implement streaming diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py index f95970c..61061bb 100644 --- a/llm_server/cluster/backend.py +++ b/llm_server/cluster/backend.py @@ -14,8 +14,11 @@ def test_backend(backend_url: str, test_prompt: bool = False): "temperature": 0, "max_new_tokens": 3, } - success, response, err = generator(data, backend_url, timeout=10) - if not success or not response or err: + try: + success, response, err = generator(data, backend_url, timeout=10) + if not success or not response or err: + return False, {} + except: return False, {} i = get_info(backend_url, backend_info['mode']) if not i.get('model'): diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index f05b37c..c924d38 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,12 +1,14 @@ from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config def generator(request_json_body, cluster_backend, timeout: int = None): - if opts.mode == 'oobabooga': + mode = cluster_config.get_backend(cluster_backend)['mode'] + if mode == 'ooba': # from .oobabooga.generate import generate # return generate(request_json_body) raise NotImplementedError - elif opts.mode == 'vllm': + elif mode == 'vllm': from .vllm.generate import generate return generate(request_json_body, cluster_backend, timeout=timeout) else: diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index ccc8db8..2ac2beb 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -12,6 +12,7 @@ class LLMBackend: def __init__(self, backend_url: str): self.backend_url = backend_url + self.backend_info = cluster_config.get_backend(self.backend_url) def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers): raise 
NotImplementedError @@ -44,8 +45,7 @@ class LLMBackend: def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]: prompt_len = get_token_count(prompt, self.backend_url) - token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings'] + token_limit = self.backend_info['model_config']['max_position_embeddings'] if prompt_len > token_limit - 10: - model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str) - return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size' + return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size' return True, None diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index 0100c7f..4cf2951 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -20,19 +20,17 @@ def generate_oai_string(length=24): def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]: - tokenizer = tiktoken.get_encoding("cl100k_base") - - def get_token_count_tiktoken_thread(msg): - return len(tokenizer.encode(msg["content"])) + def get_token_count_thread(msg): + return get_token_count(msg["content"], backend_url) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt)) + token_counts = list(executor.map(get_token_count_thread, prompt)) total_tokens = sum(token_counts) - formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens + formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens # If total tokens exceed the limit, start trimming - if total_tokens > context_token_limit: + if total_tokens + formatting_tokens > context_token_limit: while True: while total_tokens + formatting_tokens > context_token_limit: # Calculate the index to start removing messages from @@ -45,15 +43,11 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt): break - def get_token_count_thread(msg): - return get_token_count(msg["content"], backend_url) - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: token_counts = list(executor.map(get_token_count_thread, prompt)) total_tokens = sum(token_counts) formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens - if total_tokens + formatting_tokens > context_token_limit: # Start over, but this time calculate the token count using the backend with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: @@ -65,11 +59,7 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str: tokenizer = tiktoken.get_encoding("cl100k_base") - - def get_token_count_tiktoken_thread(msg): - return len(tokenizer.encode(msg)) - - token_count = get_token_count_tiktoken_thread(prompt) + token_count = get_token_count(prompt, backend_url) # If total tokens exceed the limit, start trimming if token_count > context_token_limit: @@ -80,21 +70,17 @@ 
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) while remove_index < len(prompt): prompt = prompt[:remove_index] + prompt[remove_index + 100:] - token_count = get_token_count_tiktoken_thread(prompt) + token_count = len(tokenizer.encode(prompt)) if token_count <= context_token_limit or remove_index == len(prompt): break - def get_token_count_thread(msg): - return get_token_count(msg, backend_url) - - token_count = get_token_count_thread(prompt) - + token_count = get_token_count(prompt, backend_url) if token_count > context_token_limit: # Start over, but this time calculate the token count using the backend - token_count = get_token_count_thread(prompt) + token_count = get_token_count(prompt, backend_url) else: break - + print(token_count) return prompt diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index 5cad1a4..d51b1de 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -1,29 +1,35 @@ -import requests +import asyncio + +import aiohttp import tiktoken from llm_server import opts -from llm_server.cluster.cluster_config import cluster_config def tokenize(prompt: str, backend_url: str) -> int: assert backend_url if not prompt: - # The tokenizers have issues when the prompt is None. return 0 - tokenizer = tiktoken.get_encoding("cl100k_base") - token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings'] - # First we tokenize it locally to determine if it's worth sending it to the backend. - initial_estimate = len(tokenizer.encode(prompt)) - if initial_estimate <= token_limit + 200: - try: - r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) - j = r.json() - return j['length'] - except Exception as e: - print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') - return len(tokenizer.encode(prompt)) + 10 - else: - # If the result was greater than our context size, return the estimate. - # We won't be sending it through the backend so it does't need to be accurage. 
- return initial_estimate + async def run(): + tokenizer = tiktoken.get_encoding("cl100k_base") + + async def send_chunk(chunk): + try: + async with session.post(f'{backend_url}/tokenize', json={'input': chunk}, verify_ssl=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) as response: + j = await response.json() + return j['length'] + except Exception as e: + print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') + return len(tokenizer.encode(chunk)) + 10 + + chunk_size = 300 + chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)] + + async with aiohttp.ClientSession() as session: + tasks = [send_chunk(chunk) for chunk in chunks] + lengths = await asyncio.gather(*tasks) + + return sum(lengths) + + return asyncio.run(run()) diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index a272960..909848e 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -13,7 +13,7 @@ class OobaRequestHandler(RequestHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def handle_request(self): + def handle_request(self, return_ok: bool = True): assert not self.used request_valid, invalid_response = self.validate_request() @@ -25,14 +25,19 @@ class OobaRequestHandler(RequestHandler): llm_request = {**self.parameters, 'prompt': prompt} _, backend_response = self.generate_response(llm_request) - return backend_response + if return_ok: + # Always return 200 so ST displays our error messages + return backend_response[0], 200 + else: + # The OpenAI route needs to detect 429 errors. + return backend_response def handle_ratelimited(self, do_log: bool = True): msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' backend_response = self.handle_error(msg) if do_log: log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) - return backend_response[0], 200 # We only return the response from handle_error(), not the error code + return backend_response[0], 429 def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true' diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index 0e716e9..e00d665 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -8,10 +8,11 @@ from llm_server.custom_redis import redis from . import openai_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler +from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... 
import opts from ...database.database import log_prompt from ...llm.generator import generator -from ...llm.openai.oai_to_vllm import oai_to_vllm +from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit @@ -24,11 +25,6 @@ def openai_chat_completions(): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body) - - if handler.cluster_backend_info['mode'] != 'vllm': - # TODO: implement other backends - raise NotImplementedError - if not request_json_body.get('stream'): try: return handler.handle_request() @@ -37,30 +33,51 @@ def openai_chat_completions(): return 'Internal server error', 500 else: if not opts.enable_streaming: - # TODO: return a proper OAI error message - return 'disabled', 401 + return 'DISABLED', 401 + + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode']) if opts.openai_silent_trim: - handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) + handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)) + else: + handler.prompt = transform_messages_to_prompt(handler.request.json['messages']) response_status_code = 0 start_time = time.time() + request_valid, invalid_response = handler.validate_request() if not request_valid: return invalid_response else: - if opts.openai_silent_trim: - oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) - else: - oai_messages = handler.request.json['messages'] - - handler.prompt = transform_messages_to_prompt(oai_messages) - handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode']) msg_to_backend = { **handler.parameters, 'prompt': handler.prompt, 'stream': True, } + + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) + if not event: + log_prompt( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + response_status_code, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() + + # Wait for a worker to get our request and discard it. + _, _, _ = event.wait() + try: response = generator(msg_to_backend, handler.backend_url) r_headers = dict(request.headers) @@ -69,57 +86,61 @@ def openai_chat_completions(): oai_string = generate_oai_string(30) def generate(): - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? 
- continue + try: + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue - data = { - "id": f"chatcmpl-{oai_string}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' + data = { + "id": f"chatcmpl-{oai_string}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - - log_prompt( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - response_status_code, - r_url, - handler.backend_url, - ) + log_prompt( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + response_status_code, + r_url, + handler.backend_url, + ) + finally: + # The worker incremented it, we'll decrement it. + decrement_ip_count(handler.client_ip, 'processing_ips') + decr_active_workers(handler.selected_model, handler.backend_url) return Response(generate(), mimetype='text/event-stream') - except: - # TODO: simulate OAI here - raise Exception + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 41d1d3b..7bed9fa 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -8,6 +8,7 @@ from llm_server.custom_redis import redis from . import openai_bp from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler +from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... 
import opts from ...database.database import log_prompt from ...llm import get_token_count @@ -24,80 +25,98 @@ def openai_completions(): if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - try: - handler = OobaRequestHandler(incoming_request=request) + handler = OobaRequestHandler(incoming_request=request) - if handler.cluster_backend_info['mode'] != 'vllm': - # TODO: implement other backends - raise NotImplementedError + if handler.cluster_backend_info['mode'] != 'vllm': + # TODO: implement other backends + raise NotImplementedError - invalid_oai_err_msg = validate_oai(handler.request_json_body) - if invalid_oai_err_msg: - return invalid_oai_err_msg - handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode']) + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode']) - # Convert parameters to the selected backend type - if opts.openai_silent_trim: - handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) - else: - # The handle_request() call below will load the prompt so we don't have - # to do anything else here. - pass + if opts.openai_silent_trim: + handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) + else: + # The handle_request() call below will load the prompt so we don't have + # to do anything else here. 
+ pass - if not request_json_body.get('stream'): - response, status_code = handler.handle_request() - if status_code != 200: - return status_code - output = response.json['results'][0]['text'] + if not request_json_body.get('stream'): + response, status_code = handler.handle_request(return_ok=False) + if status_code == 429: + return handler.handle_ratelimited() + output = response.json['results'][0]['text'] - # TODO: async/await - prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url) - response_tokens = get_token_count(output, handler.backend_url) - running_model = redis.get('running_model', 'ERROR', dtype=str) + # TODO: async/await + prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url) + response_tokens = get_token_count(output, handler.backend_url) + running_model = redis.get('running_model', 'ERROR', dtype=str) - response = jsonify({ - "id": f"cmpl-{generate_oai_string(30)}", - "object": "text_completion", - "created": int(time.time()), - "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'), - "choices": [ - { - "text": output, - "index": 0, - "logprobs": None, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": response_tokens, - "total_tokens": prompt_tokens + response_tokens + response = jsonify({ + "id": f"cmpl-{generate_oai_string(30)}", + "object": "text_completion", + "created": int(time.time()), + "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'), + "choices": [ + { + "text": output, + "index": 0, + "logprobs": None, + "finish_reason": "stop" } - }) + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": response_tokens, + "total_tokens": prompt_tokens + response_tokens + } + }) - stats = redis.get('proxy_stats', dtype=dict) - if stats: - response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] - return response, 200 + stats = redis.get('proxy_stats', dtype=dict) + if stats: + response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] + return response, 200 + else: + if not opts.enable_streaming: + return 'DISABLED', 401 + + response_status_code = 0 + start_time = time.time() + + request_valid, invalid_response = handler.validate_request() + if not request_valid: + return invalid_response else: - if not opts.enable_streaming: - # TODO: return a proper OAI error message - return 'disabled', 401 + handler.prompt = handler.request_json_body['prompt'] + msg_to_backend = { + **handler.parameters, + 'prompt': handler.prompt, + 'stream': True, + } - response_status_code = 0 - start_time = time.time() + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) + if not event: + log_prompt( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + response_status_code, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() - request_valid, invalid_response = handler.validate_request() - if not request_valid: - # TODO: simulate OAI here - raise Exception('TODO: simulate OAI here') - else: - handler.prompt = handler.request_json_body['prompt'] - msg_to_backend = { - **handler.parameters, - 'prompt': handler.prompt, - 'stream': True, - } + # Wait for a worker to get our request and discard it. 
+ _, _, _ = event.wait() + + try: response = generator(msg_to_backend, handler.backend_url) r_headers = dict(request.headers) r_url = request.url @@ -105,57 +124,61 @@ def openai_completions(): oai_string = generate_oai_string(30) def generate(): - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue + try: + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue - data = { - "id": f"chatcmpl-{oai_string}", - "object": "text_completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' + data = { + "id": f"cmpl-{oai_string}", + "object": "text_completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - - log_prompt( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - response_status_code, - r_url, - handler.backend_url, - ) + log_prompt( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + response_status_code, + r_url, + handler.backend_url, + ) + finally: + # The worker incremented it, we'll decrement it. 
+ decrement_ip_count(handler.client_ip, 'processing_ips') + decr_active_workers(handler.selected_model, handler.backend_url) return Response(generate(), mimetype='text/event-stream') - except Exception: - traceback.print_exc() - return 'Internal Server Error', 500 + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index 6b9ff98..541c2c9 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -10,8 +10,9 @@ from flask import Response, jsonify, make_response import llm_server from llm_server import opts +from llm_server.cluster.model_choices import get_model_choices from llm_server.custom_redis import redis -from llm_server.database.database import is_api_key_moderated +from llm_server.database.database import is_api_key_moderated, log_prompt from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.routes.request_handler import RequestHandler @@ -70,9 +71,24 @@ class OpenAIRequestHandler(RequestHandler): return backend_response, backend_response_status_code def handle_ratelimited(self, do_log: bool = True): - # TODO: return a simulated OpenAI error message - # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another. - return 'Ratelimited', 429 + _, default_backend_info = get_model_choices() + w = int(default_backend_info['estimated_wait']) if default_backend_info['estimated_wait'] > 0 else 2 + response = jsonify({ + "error": { + "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. 
Contact us through our help center at help.openai.com if you continue to have issues.", + "type": "rate_limit_exceeded", + "param": None, + "code": None + } + }) + response.headers['x-ratelimit-limit-requests'] = '2' + response.headers['x-ratelimit-remaining-requests'] = '0' + response.headers['x-ratelimit-reset-requests'] = f"{w}s" + + if do_log: + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) + + return response, 429 def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: return jsonify({ diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index ff83e76..a595b89 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -209,7 +209,7 @@ class RequestHandler: if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0: return False else: - print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.') + print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.') return True def handle_request(self) -> Tuple[flask.Response, int]: diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index c9b9c0d..9417151 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -115,6 +115,10 @@ def do_stream(ws, model_name): err_msg = r.json['results'][0]['text'] send_err_and_quit(err_msg) return + + # Wait for a worker to get our request and discard it. + _, _, _ = event.wait() + try: response = generator(llm_request, handler.backend_url) if not response: diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index df4e3be..6e37720 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -6,6 +6,7 @@ from llm_server.custom_redis import flask_cache from . import bp from ... import opts from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model +from ...cluster.cluster_config import cluster_config @bp.route('/v1/model', methods=['GET']) diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index e92052e..07de40e 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -21,8 +21,10 @@ def worker(): incr_active_workers(selected_model, backend_url) if not request_json_body: - # This was a dummy request from the websocket handler. + # This was a dummy request from the websocket handlers. # We're going to let the websocket handler decrement processing_ips and active_gen_workers. + event = DataEvent(event_id) + event.set((True, None, None)) continue try: diff --git a/requirements.txt b/requirements.txt index 28e818f..bcd1eeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ openai~=0.28.0 urllib3~=2.0.4 flask-sock==0.6.0 gunicorn==21.2.0 -redis==5.0.1 \ No newline at end of file +redis==5.0.1 +aiohttp==3.8.5 +asyncio==3.4.3 \ No newline at end of file
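
A minimal client sketch for exercising the streaming chat-completions route this patch finishes. It is illustrative only and not part of the patch: the base URL / blueprint mount point, the Bearer token, and the model name are assumptions to adjust for your deployment; the SSE framing it parses ("data: {json}" chunks with choices[0].delta.content, terminated by "data: [DONE]") is the format emitted by the generate() function above.

# Illustrative sketch only. BASE_URL, API_KEY, and the model name below are assumptions;
# point them at wherever openai_bp is mounted and at a token from the token_auth table.
import json
import requests

BASE_URL = 'http://localhost:5000/api/openai'  # assumed mount point of the OpenAI blueprint
API_KEY = 'example-token'                      # assumed token; omit the header if auth is disabled

payload = {
    'model': 'gpt-3.5-turbo',  # the proxy may ignore or remap this field
    'stream': True,
    'messages': [{'role': 'user', 'content': 'Say hello.'}],
}

with requests.post(f'{BASE_URL}/v1/chat/completions',
                   headers={'Authorization': f'Bearer {API_KEY}'},
                   json=payload, stream=True, timeout=120) as r:
    for line in r.iter_lines(decode_unicode=True):
        # The server emits SSE lines of the form "data: {json}" and finishes with "data: [DONE]".
        if not line or not line.startswith('data: '):
            continue
        data = line[len('data: '):]
        if data == '[DONE]':
            break
        chunk = json.loads(data)
        print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)
print()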