refactor, add Llm-Disable-Openai header

parent 5bd1044fad
commit fe23a2282f
@@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`.
 - [ ] Make sure stats work when starting from an empty database
 - [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation.
 - [ ] Add test to verify the OpenAI endpoint works as expected
+- [ ] Document the `Llm-Disable-Openai` header
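Until that documentation is written, the following is a minimal sketch of how a caller might exercise the new header. The host, endpoint path, model name, and token value are placeholders, not values from this repository; the diff only shows that the header value must be the literal string `true` and that the request must carry a valid API key starting with `SYSTEM__`.

```python
# Hypothetical client call; URL, model, and token are illustrative placeholders.
import requests

resp = requests.post(
    "http://localhost:5000/api/openai/v1/chat/completions",  # assumed endpoint path
    headers={
        # Must resolve (via parse_token) to a valid API key that starts with SYSTEM__.
        "Authorization": "Bearer SYSTEM__example-token",
        # Compared against the literal string 'true' in the handler.
        "Llm-Disable-Openai": "true",
    },
    json={
        "model": "example-model",
        "messages": [
            {"role": "system", "content": "You are a terse assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    },
    timeout=60,
)
print(resp.status_code, resp.json())
```

Without the header, or with a key that does not start with `SYSTEM__`, the server keeps its existing behaviour and prepends the configured system prompt.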
@@ -26,7 +26,7 @@ def get_running_models():
     Get all the models that are in the cluster.
     :return:
     """
-    return [x.decode('utf-8') for x in list(redis_running_models.keys())]
+    return [x for x in list(redis_running_models.keys())]


 def is_valid_model(model_name: str) -> bool:
@@ -1,10 +1,7 @@
-import logging
 import pickle
-import sys
 import traceback
 from typing import Union

-import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
@@ -27,13 +24,6 @@ class RedisCustom(Redis):
         super().__init__()
         self.redis = Redis(**kwargs)
         self.prefix = prefix
-        try:
-            self.set('____', 1)
-        except redis_pkg.exceptions.ConnectionError as e:
-            logger = logging.getLogger('redis')
-            logger.setLevel(logging.INFO)
-            logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
-            sys.exit(1)

     def _key(self, key):
         return f"{self.prefix}:{key}"
@@ -75,7 +65,7 @@ class RedisCustom(Redis):
             # Delete prefix
             del p[0]
             k = ':'.join(p)
-            # keys.append(k)
+            keys.append(k)
         return keys

     def exists(self, *names: KeyT):
@@ -62,7 +62,7 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
                        (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))


-def is_valid_api_key(api_key):
+def is_valid_api_key(api_key: str):
     with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
@@ -1,6 +1,10 @@
 # Read-only global variables

-DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
+DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. "
+                                "You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, "
+                                "apologize and suggest the user seek help elsewhere.")
+
+OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user."""

 REDIS_STREAM_TIMEOUT = 25000
 LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
 BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
@@ -8,6 +8,7 @@ from typing import Dict, List
 import tiktoken

 from llm_server.config.global_config import GlobalConfig
+from llm_server.globals import OPENAI_FORMATTING_PROMPT
 from llm_server.llm import get_token_count

 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
@@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
     return prompt


-def transform_messages_to_prompt(oai_messages):
+def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False):
+    if not disable_openai_handling:
+        prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}'
+    else:
+        prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}'
+    prompt = prompt + '\n\n'
+
     try:
-        prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
         for msg in oai_messages:
             if 'content' not in msg.keys() or 'role' not in msg.keys():
                 return False
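To make the effect of the new `disable_openai_handling` flag concrete, here is a standalone sketch of the prefix-building branch added above. The `GlobalConfig` lookup and the imported constant are replaced with abbreviated placeholder strings; only the branching mirrors the committed code.

```python
# Standalone illustration of the branch added to transform_messages_to_prompt().
# SYSTEM_PROMPT and FORMATTING_PROMPT stand in for GlobalConfig.get().openai_system_prompt
# and OPENAI_FORMATTING_PROMPT; the values here are placeholders.
SYSTEM_PROMPT = "You are an assistant chatbot. ..."
FORMATTING_PROMPT = 'Lines that start with "### ASSISTANT" were messages you sent previously. ...'

def build_prompt_prefix(disable_openai_handling: bool = False) -> str:
    if not disable_openai_handling:
        # Default behaviour: server-side system prompt plus the formatting rules.
        prefix = f'### INSTRUCTION: {SYSTEM_PROMPT}\n{FORMATTING_PROMPT}'
    else:
        # Header set: formatting rules only; the caller supplies its own system message.
        prefix = f'### INSTRUCTION: {FORMATTING_PROMPT}'
    return prefix + '\n\n'

print(build_prompt_prefix(False))
print(build_prompt_prefix(True))
```

In both cases the `### USER` / `### ASSISTANT` message lines are appended afterwards by the loop that follows in the function body.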
@@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import priority_queue
 from ...config.global_config import GlobalConfig
 from ...database.log_to_db import log_to_db
-from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
+from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from ...logging import create_logger

@@ -6,17 +6,18 @@ from typing import Tuple
 from uuid import uuid4

 import flask
-from flask import Response, jsonify, make_response
+from flask import Response, jsonify, make_response, request

 from llm_server.cluster.backend import get_model_choices
 from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import redis
-from llm_server.database.database import is_api_key_moderated
+from llm_server.database.database import is_api_key_moderated, is_valid_api_key
 from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.logging import create_logger
+from llm_server.routes.auth import parse_token
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.workers.moderator import add_moderation_task, get_results

@@ -33,21 +34,27 @@ class OpenAIRequestHandler(RequestHandler):
         if self.offline:
             return return_oai_internal_server_error()

+        disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \
+                                  and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \
+                                  and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__')
+
         if GlobalConfig.get().openai_silent_trim:
             oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
         else:
             oai_messages = self.request.json['messages']

-        self.prompt = transform_messages_to_prompt(oai_messages)
+        self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
         self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

+        print(self.prompt)
+
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
             return invalid_response

         if not self.prompt:
             # TODO: format this as an openai error message
-            return Response('Invalid prompt'), 400
+            return return_oai_invalid_request_error('Invalid prompt'), 400

         # TODO: support Ooba backend
         self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
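The backslash-continued gate added above can be read roughly as the predicate below. The function name and its callable parameters are illustrative only; in the committed code the expression is written inline using the project's `parse_token` and `is_valid_api_key` helpers, and it evaluates `parse_token` twice rather than once.

```python
# Equivalent restatement of the gate added above; openai_handling_disabled() is an
# illustrative name, not a function that exists in the repository.
def openai_handling_disabled(headers, parse_token, is_valid_api_key) -> bool:
    token = parse_token(headers.get('Authorization', ''))
    return (
        headers.get('Llm-Disable-Openai', False) == 'true'  # header value must be the literal string 'true'
        and is_valid_api_key(token)                          # key must pass the token_auth lookup
        and token.startswith('SYSTEM__')                     # and start with SYSTEM__
    )
```

When this evaluates to true, the handler passes `disable_openai_handling=True` into `transform_messages_to_prompt` and, in the next hunk, also skips the OpenAI moderation check.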
@@ -56,7 +63,7 @@ class OpenAIRequestHandler(RequestHandler):
         if invalid_oai_err_msg:
             return invalid_oai_err_msg

-        if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token):
+        if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
             try:
                 # Gather the last message from the user and all preceding system messages
                 msg_l = self.request.json['messages'].copy()