diff --git a/llm_server/cluster/model_choices.py b/llm_server/cluster/model_choices.py index 31cd8cb..3df3aea 100644 --- a/llm_server/cluster/model_choices.py +++ b/llm_server/cluster/model_choices.py @@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_ from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue -from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers +from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model # TODO: give this a better name! @@ -30,7 +30,7 @@ def get_model_choices(regen: bool = False): if backend_info.get('average_generation_elapsed_sec'): avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec']) - active_gen_workers = get_active_gen_workers(model) + active_gen_workers = get_active_gen_workers_model(model) proompters_in_queue = priority_queue.len(model) if len(avg_gen_per_worker): diff --git a/llm_server/database/database.py b/llm_server/database/database.py index dec5e98..fc1aa21 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -2,15 +2,15 @@ import json import time import traceback from threading import Thread +from typing import Union -import llm_server from llm_server import opts from llm_server.cluster.cluster_config import cluster_config from llm_server.database.conn import database -from llm_server.llm.vllm import tokenize +from llm_server.llm import get_token_count -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False): +def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False): def background_task(): nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error # Try not to shove JSON into the database. 
@@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe except: pass - prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url) + prompt_tokens = get_token_count(prompt, backend_url) if not is_error: if not response_tokens: - response_tokens = llm_server.llm.get_token_count(response, backend_url) + response_tokens = get_token_count(response, backend_url) else: response_tokens = None diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index 3feb027..ba46635 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -3,6 +3,9 @@ from llm_server.custom_redis import redis def get_token_count(prompt: str, backend_url: str): + assert isinstance(prompt, str) + assert isinstance(backend_url, str) + backend_mode = redis.get('backend_mode', dtype=str) if backend_mode == 'vllm': return vllm.tokenize(prompt, backend_url) diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py index ce59e9b..9111389 100644 --- a/llm_server/llm/openai/oai_to_vllm.py +++ b/llm_server/llm/openai/oai_to_vllm.py @@ -8,11 +8,11 @@ def oai_to_vllm(request_json_body, hashes: bool, mode): request_json_body['stop'] = [] if hashes: - request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']) + request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE']) if opts.openai_force_no_hashes: request_json_body['stop'].append('### ') else: - request_json_body['stop'].extend(['\nuser:', '\nassistant:']) + request_json_body['stop'].extend(['user:', 'assistant:']) if request_json_body.get('frequency_penalty', 0) < -2: request_json_body['frequency_penalty'] = -2 diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index bd44ad8..d5a1b71 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -8,6 +8,9 @@ from llm_server import opts def tokenize(prompt: str, backend_url: str) -> int: assert backend_url + assert isinstance(prompt, str) + assert isinstance(backend_url, str) + if not prompt: # The tokenizers have issues when the prompt is None. 
return 0 diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index e59f255..c46e89f 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -62,7 +62,7 @@ def openai_chat_completions(): } # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) + event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) if not event: log_prompt( handler.client_ip, diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index e772842..6904348 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -100,7 +100,7 @@ def openai_completions(): } # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) + event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) if not event: log_prompt( handler.client_ip, diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index 541c2c9..8664695 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -8,11 +8,11 @@ from uuid import uuid4 import flask from flask import Response, jsonify, make_response -import llm_server from llm_server import opts from llm_server.cluster.model_choices import get_model_choices from llm_server.custom_redis import redis from llm_server.database.database import is_api_key_moderated, log_prompt +from llm_server.llm import get_token_count from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.routes.request_handler import RequestHandler @@ -110,9 +110,8 @@ class OpenAIRequestHandler(RequestHandler): response = re.sub(ANTI_RESPONSE_RE, '', response) response = re.sub(ANTI_CONTINUATION_RE, '', response) - # TODO: async/await - prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url) - response_tokens = llm_server.llm.get_token_count(response, self.backend_url) + prompt_tokens = get_token_count(prompt, self.backend_url) + response_tokens = get_token_count(response, self.backend_url) running_model = redis.get('running_model', 'ERROR', dtype=str) response = make_response(jsonify({ diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 5d2c6b3..a8a47b1 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -27,7 +27,6 @@ class RedisPriorityQueue: def put(self, item, priority, selected_model): event = DataEvent() - # Check if the IP is already in the dictionary and if it has reached the limit ip_count = self.redis.hget('queued_ip_count', item[1]) if ip_count: @@ -99,16 +98,20 @@ class DataEvent: priority_queue = RedisPriorityQueue() +def update_active_workers(key: str, operation: str): + if operation == 'incr': + redis.incr(f'active_gen_workers:{key}') + elif operation == 'decr': + redis.decr(f'active_gen_workers:{key}') + if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) 
< 0: + redis.set(f'active_gen_workers:{key}', 0) + + def incr_active_workers(selected_model: str, backend_url: str): - redis.incr(f'active_gen_workers:{selected_model}') - redis.incr(f'active_gen_workers:{backend_url}') + update_active_workers(selected_model, 'incr') + update_active_workers(backend_url, 'incr') def decr_active_workers(selected_model: str, backend_url: str): - redis.decr(f'active_gen_workers:{selected_model}') - if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0: - redis.set(f'active_gen_workers:{selected_model}', 0) - - redis.decr(f'active_gen_workers:{backend_url}') - if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0: - redis.set(f'active_gen_workers:{backend_url}', 0) + update_active_workers(selected_model, 'decr') + update_active_workers(backend_url, 'decr') diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index 7c425dc..4e8b8e4 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -36,6 +36,7 @@ class RequestHandler: self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token) self.backend_url = get_a_cluster_backend(selected_model) self.cluster_backend_info = cluster_config.get_backend(self.backend_url) + self.selected_model = self.cluster_backend_info['model'] if not self.cluster_backend_info.get('mode'): print(selected_model, self.backend_url, self.cluster_backend_info) @@ -43,7 +44,6 @@ class RequestHandler: self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url) self.parameters = None self.used = False - self.selected_model = selected_model redis.zadd('recent_prompters', {self.client_ip: time.time()}) def get_auth_token(self): diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index 9e1f291..7f3b2fe 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -15,13 +15,8 @@ def get_total_proompts(): return count -def get_active_gen_workers(selected_model: str = None, ): - active_gen_workers = redis.get(f'active_gen_workers:{selected_model}') - if active_gen_workers is None: - count = 0 - else: - count = int(active_gen_workers) - return count +def get_active_gen_workers_model(selected_model: str = None): + return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0) def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 9417151..ac148dd 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -11,7 +11,6 @@ from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... 
import opts from ...database.database import log_prompt from ...llm.generator import generator -from ...llm.vllm import tokenize from ...sock import sock @@ -45,7 +44,6 @@ def do_stream(ws, model_name): 'event': 'stream_end', 'message_num': 1 })) - ws.close() log_prompt(ip=handler.client_ip, token=handler.token, prompt=input_prompt, @@ -56,7 +54,7 @@ def do_stream(ws, model_name): backend_response_code=response_status_code, request_url=r_url, backend_url=handler.cluster_backend_info, - response_tokens=tokenize(generated_text, handler.backend_url), + response_tokens=None, is_error=True ) @@ -67,195 +65,192 @@ def do_stream(ws, model_name): r_url = request.url message_num = 0 - while ws.connected: - message = ws.receive() - request_valid_json, request_json_body = validate_json(message) + try: + while ws.connected: + message = ws.receive() + request_valid_json, request_json_body = validate_json(message) - if not request_valid_json or not request_json_body.get('prompt'): - ws.close() - return 'Invalid JSON', 400 - else: - if opts.mode != 'vllm': - # TODO: implement other backends - raise NotImplementedError - - auth_failure = require_api_key(request_json_body) - if auth_failure: - ws.close() - return auth_failure - - handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body) - generated_text = '' - input_prompt = request_json_body['prompt'] - response_status_code = 0 - start_time = time.time() - - err_msg = None - if handler.is_client_ratelimited(): - r, _ = handler.handle_ratelimited(do_log=False) - err_msg = r.json['results'][0]['text'] + if not request_valid_json or not request_json_body.get('prompt'): + return 'Invalid JSON', 400 else: - request_valid, invalid_response = handler.validate_request(prompt=input_prompt) - if not request_valid: - err_msg = invalid_response[0].json['results'][0]['text'] - if err_msg: - send_err_and_quit(err_msg) - return + if opts.mode != 'vllm': + # TODO: implement other backends + raise NotImplementedError - llm_request = { - **handler.parameters, - 'prompt': input_prompt, - 'stream': True, - } + auth_failure = require_api_key(request_json_body) + if auth_failure: + return auth_failure - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url) - if not event: - r, _ = handler.handle_ratelimited() - err_msg = r.json['results'][0]['text'] - send_err_and_quit(err_msg) - return + handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body) + generated_text = '' + input_prompt = request_json_body['prompt'] + response_status_code = 0 + start_time = time.time() - # Wait for a worker to get our request and discard it. - _, _, _ = event.wait() + err_msg = None + if handler.is_client_ratelimited(): + r, _ = handler.handle_ratelimited(do_log=False) + err_msg = r.json['results'][0]['text'] + else: + request_valid, invalid_response = handler.validate_request(prompt=input_prompt) + if not request_valid: + err_msg = invalid_response[0].json['results'][0]['text'] + if err_msg: + send_err_and_quit(err_msg) + return - try: - response = generator(llm_request, handler.backend_url) - if not response: - error_msg = 'Failed to reach backend while streaming.' 
-                    print('Streaming failed:', error_msg)
-                    msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
+                llm_request = {
+                    **handler.parameters,
+                    'prompt': input_prompt,
+                    'stream': True,
+                }
+
+                # Add a dummy event to the queue and wait for it to reach a worker
+                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+                if not event:
+                    r, _ = handler.handle_ratelimited()
+                    err_msg = r.json['results'][0]['text']
+                    send_err_and_quit(err_msg)
+                    return
+
+                # Wait for a worker to get our request and discard it.
+                _, _, _ = event.wait()
+
+                try:
+                    response = generator(llm_request, handler.backend_url)
+
+                    if not response:
+                        error_msg = 'Failed to reach backend while streaming.'
+                        print('Streaming failed:', error_msg)
+                        msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': msg
+                        }))
+                    else:
+                        # Be extra careful when getting attributes from the response object
+                        try:
+                            response_status_code = response.status_code
+                        except:
+                            response_status_code = 0
+
+                        partial_response = b''
+
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(input_prompt + generated_text)[1]
+                                            generated_text = generated_text + new
+                                        except IndexError:
+                                            # ????
+                                            continue
+                                        try:
+                                            ws.send(json.dumps({
+                                                'event': 'text_stream',
+                                                'message_num': message_num,
+                                                'text': new
+                                            }))
+                                        except:
+                                            # The client has closed the stream.
+                                            if request:
+                                                # Cancel the backend?
+ request.close() + end_time = time.time() + elapsed_time = end_time - start_time + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=None + ) + return + + message_num += 1 + partial_response = b'' # Reset the partial response + + # If there is no more data, break the loop + if not chunk: + break + + end_time = time.time() + elapsed_time = end_time - start_time + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=None, + is_error=not response + ) + except: + traceback.print_exc() + generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] ws.send(json.dumps({ 'event': 'text_stream', 'message_num': message_num, - 'text': msg + 'text': generated_text })) - else: - # Be extra careful when getting attributes from the response object - try: - response_status_code = response.status_code - except: - response_status_code = 0 - - partial_response = b'' - - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(input_prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue - try: - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': new - })) - except: - # The has client closed the stream. 
- if request: - request.close() - try: - ws.close() - except: - pass - end_time = time.time() - elapsed_time = end_time - start_time - log_prompt(ip=handler.client_ip, - token=handler.token, - prompt=input_prompt, - response=generated_text, - gen_time=elapsed_time, - parameters=handler.parameters, - headers=r_headers, - backend_response_code=response_status_code, - request_url=r_url, - backend_url=handler.backend_url, - response_tokens=tokenize(generated_text, handler.backend_url) - ) - - return - - message_num += 1 - partial_response = b'' # Reset the partial response - - # If there is no more data, break the loop - if not chunk: - break - - end_time = time.time() - elapsed_time = end_time - start_time - log_prompt(ip=handler.client_ip, - token=handler.token, - prompt=input_prompt, - response=generated_text, - gen_time=elapsed_time, - parameters=handler.parameters, - headers=r_headers, - backend_response_code=response_status_code, - request_url=r_url, - backend_url=handler.backend_url, - response_tokens=tokenize(generated_text, handler.backend_url), - is_error=not response - ) - except: - traceback.print_exc() - generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': generated_text - })) - if request: - request.close() - ws.close() - log_prompt(ip=handler.client_ip, - token=handler.token, - prompt=input_prompt, - response=generated_text, - gen_time=None, - parameters=handler.parameters, - headers=r_headers, - backend_response_code=response_status_code, - request_url=r_url, - backend_url=handler.backend_url, - response_tokens=tokenize(generated_text, handler.backend_url), - is_error=True - ) - return - finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers(handler.selected_model, handler.backend_url) - try: - ws.send(json.dumps({ - 'event': 'stream_end', - 'message_num': message_num - })) - except: - # The client closed the stream. - end_time = time.time() - elapsed_time = end_time - start_time - log_prompt(ip=handler.client_ip, - token=handler.token, - prompt=input_prompt, - response=generated_text, - gen_time=elapsed_time, - parameters=handler.parameters, - headers=r_headers, - backend_response_code=response_status_code, - request_url=r_url, - backend_url=handler.backend_url, - response_tokens=tokenize(generated_text, handler.backend_url) - ) - try: - ws.close() # this is important if we encountered and error and exited early. - except: - pass + if request: + request.close() + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=None, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=None, + is_error=True + ) + return + finally: + # The worker incremented it, we'll decrement it. + decrement_ip_count(handler.client_ip, 'processing_ips') + decr_active_workers(handler.selected_model, handler.backend_url) + try: + ws.send(json.dumps({ + 'event': 'stream_end', + 'message_num': message_num + })) + except: + # The client closed the stream. 
+ end_time = time.time() + elapsed_time = end_time - start_time + log_prompt(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=None + ) + finally: + try: + # Must close the connection or greenlets will complain. + ws.close() + except: + pass diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index 0aff9ac..c5eb12a 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -11,19 +11,23 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip def worker(): while True: (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get() + if not backend_url: + backend_url = get_a_cluster_backend(selected_model) backend_info = cluster_config.get_backend(backend_url) + # The backend could have died between when the request was + # submitted and now, so let's double check it's still online. if not backend_info['online']: old = backend_url backend_url = get_a_cluster_backend() backend_info = cluster_config.get_backend(backend_url) print(f'Backend {old} offline. Request was redirected to {backend_url}') - del old + del old # gc if not selected_model: selected_model = backend_info['model'] - # This wait time is "invisible", meaning the worker may as + # This wait time will be "invisible", meaning the worker may as # well be still waiting to get an item from the queue. need_to_wait(backend_url) @@ -32,7 +36,8 @@ def worker(): if not request_json_body: # This was a dummy request from the websocket handlers. - # We're going to let the websocket handler decrement processing_ips and active_gen_workers. + # We're going to let the websocket handler decrement + # processing_ips and active_gen_workers. event = DataEvent(event_id) event.set((True, None, None)) continue diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py old mode 100644 new mode 100755 diff --git a/requirements.txt b/requirements.txt index bcd1eeb..28e818f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,4 @@ openai~=0.28.0 urllib3~=2.0.4 flask-sock==0.6.0 gunicorn==21.2.0 -redis==5.0.1 -aiohttp==3.8.5 -asyncio==3.4.3 \ No newline at end of file +redis==5.0.1 \ No newline at end of file diff --git a/server.py b/server.py index 1d89ca2..382c7ff 100644 --- a/server.py +++ b/server.py @@ -24,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.sock import init_socketio +# TODO: implement blind RRD controlled via header and only used when there is a queue on the primary backend(s) # TODO: is frequency penalty the same as ooba repetition penalty??? # TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions # TODO: if a backend is at its limit of concurrent requests, choose a different one @@ -93,7 +94,6 @@ create_db() def home(): base_client_api = redis.get('base_client_api', dtype=str) stats = generate_stats() - model_choices, default_backend_info = get_model_choices() if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens: