refactor, add Llm-Disable-Openai header

This commit is contained in:
Cyberes 2024-05-07 17:41:53 -06:00
parent 5bd1044fad
commit fe23a2282f
8 changed files with 30 additions and 22 deletions

View File

@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`.
- [ ] Make sure stats work when starting from an empty database
- [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation.
- [ ] Add test to verify the OpenAI endpoint works as expected
- [ ] Document the `Llm-Disable-Openai` header

View File

@ -26,7 +26,7 @@ def get_running_models():
Get all the models that are in the cluster.
:return:
"""
return [x.decode('utf-8') for x in list(redis_running_models.keys())]
return [x for x in list(redis_running_models.keys())]
def is_valid_model(model_name: str) -> bool:

View File

@ -1,10 +1,7 @@
import logging
import pickle
import sys
import traceback
from typing import Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
@ -27,13 +24,6 @@ class RedisCustom(Redis):
super().__init__()
self.redis = Redis(**kwargs)
self.prefix = prefix
try:
self.set('____', 1)
except redis_pkg.exceptions.ConnectionError as e:
logger = logging.getLogger('redis')
logger.setLevel(logging.INFO)
logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
sys.exit(1)
def _key(self, key):
return f"{self.prefix}:{key}"
@ -75,7 +65,7 @@ class RedisCustom(Redis):
# Delete prefix
del p[0]
k = ':'.join(p)
# keys.append(k)
keys.append(k)
return keys
def exists(self, *names: KeyT):

View File

@ -62,7 +62,7 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
def is_valid_api_key(api_key):
def is_valid_api_key(api_key: str):
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()

View File

@ -1,6 +1,10 @@
# Read-only global variables
DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. "
"You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, "
"apologize and suggest the user seek help elsewhere.")
OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user."""
REDIS_STREAM_TIMEOUT = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -8,6 +8,7 @@ from typing import Dict, List
import tiktoken
from llm_server.config.global_config import GlobalConfig
from llm_server.globals import OPENAI_FORMATTING_PROMPT
from llm_server.llm import get_token_count
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
return prompt
def transform_messages_to_prompt(oai_messages):
def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False):
if not disable_openai_handling:
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}'
else:
prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}'
prompt = prompt + '\n\n'
try:
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
for msg in oai_messages:
if 'content' not in msg.keys() or 'role' not in msg.keys():
return False

View File

@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger

View File

@ -6,17 +6,18 @@ from typing import Tuple
from uuid import uuid4
import flask
from flask import Response, jsonify, make_response
from flask import Response, jsonify, make_response, request
from llm_server.cluster.backend import get_model_choices
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.database import is_api_key_moderated, is_valid_api_key
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@ -33,21 +34,27 @@ class OpenAIRequestHandler(RequestHandler):
if self.offline:
return return_oai_internal_server_error()
disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \
and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \
and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__')
if GlobalConfig.get().openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
print(self.prompt)
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
return return_oai_invalid_request_error('Invalid prompt'), 400
# TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
@ -56,7 +63,7 @@ class OpenAIRequestHandler(RequestHandler):
if invalid_oai_err_msg:
return invalid_oai_err_msg
if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token):
if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
try:
# Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy()