From fe23a2282faf66e973b4bc9bd594b8a35261507c Mon Sep 17 00:00:00 2001 From: Cyberes Date: Tue, 7 May 2024 17:41:53 -0600 Subject: [PATCH] refactor, add Llm-Disable-Openai header --- README.md | 1 + llm_server/cluster/backend.py | 2 +- llm_server/custom_redis.py | 12 +----------- llm_server/database/database.py | 2 +- llm_server/globals.py | 6 +++++- llm_server/llm/openai/transform.py | 10 ++++++++-- llm_server/routes/openai/chat_completions.py | 2 +- llm_server/routes/openai_request_handler.py | 17 ++++++++++++----- 8 files changed, 30 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 1fbc2d9..e2089d4 100644 --- a/README.md +++ b/README.md @@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`. - [ ] Make sure stats work when starting from an empty database - [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation. - [ ] Add test to verify the OpenAI endpoint works as expected +- [ ] Document the `Llm-Disable-Openai` header \ No newline at end of file diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py index 20e8140..119c378 100644 --- a/llm_server/cluster/backend.py +++ b/llm_server/cluster/backend.py @@ -26,7 +26,7 @@ def get_running_models(): Get all the models that are in the cluster. :return: """ - return [x.decode('utf-8') for x in list(redis_running_models.keys())] + return [x for x in list(redis_running_models.keys())] def is_valid_model(model_name: str) -> bool: diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py index 3aa338c..05f7745 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -1,10 +1,7 @@ -import logging import pickle -import sys import traceback from typing import Union -import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis @@ -27,13 +24,6 @@ class RedisCustom(Redis): super().__init__() self.redis = Redis(**kwargs) self.prefix = prefix - try: - self.set('____', 1) - except redis_pkg.exceptions.ConnectionError as e: - logger = logging.getLogger('redis') - logger.setLevel(logging.INFO) - logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?') - sys.exit(1) def _key(self, key): return f"{self.prefix}:{key}" @@ -75,7 +65,7 @@ class RedisCustom(Redis): # Delete prefix del p[0] k = ':'.join(p) - # keys.append(k) + keys.append(k) return keys def exists(self, *names: KeyT): diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 90cf219..ffd6296 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -62,7 +62,7 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_ (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) -def is_valid_api_key(api_key): +def is_valid_api_key(api_key: str): with CursorFromConnectionFromPool() as cursor: cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,)) row = cursor.fetchone() diff --git a/llm_server/globals.py b/llm_server/globals.py index 4a755cb..4c99563 100644 --- a/llm_server/globals.py +++ b/llm_server/globals.py @@ -1,6 +1,10 @@ # Read-only global variables -DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""" +DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. " + "You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, " + "apologize and suggest the user seek help elsewhere.") +OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.""" + REDIS_STREAM_TIMEOUT = 25000 LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s" BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.' diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index 88681c2..8724e76 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -8,6 +8,7 @@ from typing import Dict, List import tiktoken from llm_server.config.global_config import GlobalConfig +from llm_server.globals import OPENAI_FORMATTING_PROMPT from llm_server.llm import get_token_count ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. @@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) return prompt -def transform_messages_to_prompt(oai_messages): +def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False): + if not disable_openai_handling: + prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}' + else: + prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}' + prompt = prompt + '\n\n' + try: - prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}' for msg in oai_messages: if 'content' not in msg.keys() or 'role' not in msg.keys(): return False diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index ae5f101..f675370 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler from ..queue import priority_queue from ...config.global_config import GlobalConfig from ...database.log_to_db import log_to_db -from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error +from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...logging import create_logger diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index 5da99e5..8d5f6d9 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -6,17 +6,18 @@ from typing import Tuple from uuid import uuid4 import flask -from flask import Response, jsonify, make_response +from flask import Response, jsonify, make_response, request from llm_server.cluster.backend import get_model_choices from llm_server.config.global_config import GlobalConfig from llm_server.custom_redis import redis -from llm_server.database.database import is_api_key_moderated +from llm_server.database.database import is_api_key_moderated, is_valid_api_key from llm_server.database.log_to_db import log_to_db from llm_server.llm import get_token_count from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.logging import create_logger +from llm_server.routes.auth import parse_token from llm_server.routes.request_handler import RequestHandler from llm_server.workers.moderator import add_moderation_task, get_results @@ -33,21 +34,27 @@ class OpenAIRequestHandler(RequestHandler): if self.offline: return return_oai_internal_server_error() + disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \ + and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \ + and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__') + if GlobalConfig.get().openai_silent_trim: oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) else: oai_messages = self.request.json['messages'] - self.prompt = transform_messages_to_prompt(oai_messages) + self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling) self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) + print(self.prompt) + request_valid, invalid_response = self.validate_request() if not request_valid: return invalid_response if not self.prompt: # TODO: format this as an openai error message - return Response('Invalid prompt'), 400 + return return_oai_invalid_request_error('Invalid prompt'), 400 # TODO: support Ooba backend self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) @@ -56,7 +63,7 @@ class OpenAIRequestHandler(RequestHandler): if invalid_oai_err_msg: return invalid_oai_err_msg - if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token): + if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)): try: # Gather the last message from the user and all preceding system messages msg_l = self.request.json['messages'].copy()