refactor, add Llm-Disable-Openai header

This commit is contained in:
Cyberes 2024-05-07 17:41:53 -06:00
parent 5bd1044fad
commit fe23a2282f
8 changed files with 30 additions and 22 deletions

View File

@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`.
- [ ] Make sure stats work when starting from an empty database - [ ] Make sure stats work when starting from an empty database
- [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation. - [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation.
- [ ] Add test to verify the OpenAI endpoint works as expected - [ ] Add test to verify the OpenAI endpoint works as expected
- [ ] Document the `Llm-Disable-Openai` header

View File

@ -26,7 +26,7 @@ def get_running_models():
Get all the models that are in the cluster. Get all the models that are in the cluster.
:return: :return:
""" """
return [x.decode('utf-8') for x in list(redis_running_models.keys())] return [x for x in list(redis_running_models.keys())]
def is_valid_model(model_name: str) -> bool: def is_valid_model(model_name: str) -> bool:

View File

@ -1,10 +1,7 @@
import logging
import pickle import pickle
import sys
import traceback import traceback
from typing import Union from typing import Union
import redis as redis_pkg
import simplejson as json import simplejson as json
from flask_caching import Cache from flask_caching import Cache
from redis import Redis from redis import Redis
@ -27,13 +24,6 @@ class RedisCustom(Redis):
super().__init__() super().__init__()
self.redis = Redis(**kwargs) self.redis = Redis(**kwargs)
self.prefix = prefix self.prefix = prefix
try:
self.set('____', 1)
except redis_pkg.exceptions.ConnectionError as e:
logger = logging.getLogger('redis')
logger.setLevel(logging.INFO)
logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
sys.exit(1)
def _key(self, key): def _key(self, key):
return f"{self.prefix}:{key}" return f"{self.prefix}:{key}"
@ -75,7 +65,7 @@ class RedisCustom(Redis):
# Delete prefix # Delete prefix
del p[0] del p[0]
k = ':'.join(p) k = ':'.join(p)
# keys.append(k) keys.append(k)
return keys return keys
def exists(self, *names: KeyT): def exists(self, *names: KeyT):

View File

@ -62,7 +62,7 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
def is_valid_api_key(api_key): def is_valid_api_key(api_key: str):
with CursorFromConnectionFromPool() as cursor: with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,)) cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone() row = cursor.fetchone()

View File

@ -1,6 +1,10 @@
# Read-only global variables # Read-only global variables
DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""" DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. "
"You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, "
"apologize and suggest the user seek help elsewhere.")
OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user."""
REDIS_STREAM_TIMEOUT = 25000 REDIS_STREAM_TIMEOUT = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s" LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.' BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -8,6 +8,7 @@ from typing import Dict, List
import tiktoken import tiktoken
from llm_server.config.global_config import GlobalConfig from llm_server.config.global_config import GlobalConfig
from llm_server.globals import OPENAI_FORMATTING_PROMPT
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
return prompt return prompt
def transform_messages_to_prompt(oai_messages): def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False):
if not disable_openai_handling:
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}'
else:
prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}'
prompt = prompt + '\n\n'
try: try:
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
for msg in oai_messages: for msg in oai_messages:
if 'content' not in msg.keys() or 'role' not in msg.keys(): if 'content' not in msg.keys() or 'role' not in msg.keys():
return False return False

View File

@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ...config.global_config import GlobalConfig from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger from ...logging import create_logger

View File

@ -6,17 +6,18 @@ from typing import Tuple
from uuid import uuid4 from uuid import uuid4
import flask import flask
from flask import Response, jsonify, make_response from flask import Response, jsonify, make_response, request
from llm_server.cluster.backend import get_model_choices from llm_server.cluster.backend import get_model_choices
from llm_server.config.global_config import GlobalConfig from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated from llm_server.database.database import is_api_key_moderated, is_valid_api_key
from llm_server.database.log_to_db import log_to_db from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results from llm_server.workers.moderator import add_moderation_task, get_results
@ -33,21 +34,27 @@ class OpenAIRequestHandler(RequestHandler):
if self.offline: if self.offline:
return return_oai_internal_server_error() return return_oai_internal_server_error()
disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \
and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \
and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__')
if GlobalConfig.get().openai_silent_trim: if GlobalConfig.get().openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else: else:
oai_messages = self.request.json['messages'] oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages) self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
print(self.prompt)
request_valid, invalid_response = self.validate_request() request_valid, invalid_response = self.validate_request()
if not request_valid: if not request_valid:
return invalid_response return invalid_response
if not self.prompt: if not self.prompt:
# TODO: format this as an openai error message # TODO: format this as an openai error message
return Response('Invalid prompt'), 400 return return_oai_invalid_request_error('Invalid prompt'), 400
# TODO: support Ooba backend # TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
@ -56,7 +63,7 @@ class OpenAIRequestHandler(RequestHandler):
if invalid_oai_err_msg: if invalid_oai_err_msg:
return invalid_oai_err_msg return invalid_oai_err_msg
if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token): if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
try: try:
# Gather the last message from the user and all preceding system messages # Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy() msg_l = self.request.json['messages'].copy()