refactor, add Llm-Disable-Openai header
This commit is contained in:
parent
5bd1044fad
commit
fe23a2282f
|
@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`.
|
|||
- [ ] Make sure stats work when starting from an empty database
|
||||
- [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation.
|
||||
- [ ] Add test to verify the OpenAI endpoint works as expected
|
||||
- [ ] Document the `Llm-Disable-Openai` header
|
|
@ -26,7 +26,7 @@ def get_running_models():
|
|||
Get all the models that are in the cluster.
|
||||
:return:
|
||||
"""
|
||||
return [x.decode('utf-8') for x in list(redis_running_models.keys())]
|
||||
return [x for x in list(redis_running_models.keys())]
|
||||
|
||||
|
||||
def is_valid_model(model_name: str) -> bool:
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import redis as redis_pkg
|
||||
import simplejson as json
|
||||
from flask_caching import Cache
|
||||
from redis import Redis
|
||||
|
@ -27,13 +24,6 @@ class RedisCustom(Redis):
|
|||
super().__init__()
|
||||
self.redis = Redis(**kwargs)
|
||||
self.prefix = prefix
|
||||
try:
|
||||
self.set('____', 1)
|
||||
except redis_pkg.exceptions.ConnectionError as e:
|
||||
logger = logging.getLogger('redis')
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
|
||||
sys.exit(1)
|
||||
|
||||
def _key(self, key):
|
||||
return f"{self.prefix}:{key}"
|
||||
|
@ -75,7 +65,7 @@ class RedisCustom(Redis):
|
|||
# Delete prefix
|
||||
del p[0]
|
||||
k = ':'.join(p)
|
||||
# keys.append(k)
|
||||
keys.append(k)
|
||||
return keys
|
||||
|
||||
def exists(self, *names: KeyT):
|
||||
|
|
|
@ -62,7 +62,7 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
|
|||
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
|
||||
|
||||
def is_valid_api_key(api_key):
|
||||
def is_valid_api_key(api_key: str):
|
||||
with CursorFromConnectionFromPool() as cursor:
|
||||
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
|
||||
row = cursor.fetchone()
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# Read-only global variables
|
||||
|
||||
DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
|
||||
DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. "
|
||||
"You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, "
|
||||
"apologize and suggest the user seek help elsewhere.")
|
||||
OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user."""
|
||||
|
||||
REDIS_STREAM_TIMEOUT = 25000
|
||||
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
|
||||
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Dict, List
|
|||
import tiktoken
|
||||
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.globals import OPENAI_FORMATTING_PROMPT
|
||||
from llm_server.llm import get_token_count
|
||||
|
||||
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
|
||||
|
@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
|
|||
return prompt
|
||||
|
||||
|
||||
def transform_messages_to_prompt(oai_messages):
|
||||
def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False):
|
||||
if not disable_openai_handling:
|
||||
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}'
|
||||
else:
|
||||
prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}'
|
||||
prompt = prompt + '\n\n'
|
||||
|
||||
try:
|
||||
prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
|
||||
for msg in oai_messages:
|
||||
if 'content' not in msg.keys() or 'role' not in msg.keys():
|
||||
return False
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
|
|||
from ..queue import priority_queue
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from ...logging import create_logger
|
||||
|
||||
|
|
|
@ -6,17 +6,18 @@ from typing import Tuple
|
|||
from uuid import uuid4
|
||||
|
||||
import flask
|
||||
from flask import Response, jsonify, make_response
|
||||
from flask import Response, jsonify, make_response, request
|
||||
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated
|
||||
from llm_server.database.database import is_api_key_moderated, is_valid_api_key
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
from llm_server.workers.moderator import add_moderation_task, get_results
|
||||
|
||||
|
@ -33,21 +34,27 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
if self.offline:
|
||||
return return_oai_internal_server_error()
|
||||
|
||||
disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \
|
||||
and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \
|
||||
and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__')
|
||||
|
||||
if GlobalConfig.get().openai_silent_trim:
|
||||
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
||||
else:
|
||||
oai_messages = self.request.json['messages']
|
||||
|
||||
self.prompt = transform_messages_to_prompt(oai_messages)
|
||||
self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
print(self.prompt)
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
if not self.prompt:
|
||||
# TODO: format this as an openai error message
|
||||
return Response('Invalid prompt'), 400
|
||||
return return_oai_invalid_request_error('Invalid prompt'), 400
|
||||
|
||||
# TODO: support Ooba backend
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
@ -56,7 +63,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token):
|
||||
if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
|
||||
try:
|
||||
# Gather the last message from the user and all preceding system messages
|
||||
msg_l = self.request.json['messages'].copy()
|
||||
|
|
Reference in New Issue