improve openai endpoint, exclude system tokens in more places
parent 6459a1c91b
commit bbe5d5a8fe
@@ -24,6 +24,7 @@ config_default_vars = {
     'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
     'http_host': None,
     'admin_token': None,
+    'openai_epose_our_model': False,
 }

 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
@@ -132,7 +132,7 @@ def sum_column(table_name, column_name):
     conn = db_pool.connection()
     cursor = conn.cursor()
     try:
-        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
+        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%'")
         result = cursor.fetchone()
         return result[0] if result else 0
     finally:
@@ -145,7 +145,7 @@ def get_distinct_ips_24h():
     conn = db_pool.connection()
     cursor = conn.cursor()
     try:
-        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
+        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
         result = cursor.fetchone()
         return result[0] if result else 0
     finally:
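Note on the 'SYSTEM__%%' pattern in the two queries above: MySQL-style DBAPI drivers build the final statement with %-formatting when parameters are passed, so a literal LIKE wildcard has to be doubled. A minimal standalone sketch of that substitution step, using plain Python %-formatting as a stand-in for the driver (the timestamp value is a placeholder):

# Sketch only: shows why the LIKE wildcard is written as %% in a
# parameterized query; plain %-formatting mimics the driver's step.
query = "SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'"
args = ("'2023-09-26 00:00:00'",)  # placeholder, already-quoted value
print(query % args)
# -> ... WHERE timestamp >= '2023-09-26 00:00:00' AND token NOT LIKE 'SYSTEM__%'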
@@ -2,6 +2,7 @@ import math
 import re
 from collections import OrderedDict
 from pathlib import Path
+from typing import Union

 import simplejson as json
 from flask import make_response
@@ -56,7 +57,7 @@ def indefinite_article(word):
     return 'a'


-def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
+def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys=True):
     response = make_response(json.dumps(json_dict, indent=indent, sort_keys=sort_keys))
     response.headers['Content-Type'] = 'application/json; charset=utf-8'
     response.headers['mimetype'] = 'application/json'
@@ -28,4 +28,5 @@ enable_streaming = True
 openai_api_key = None
 backend_request_timeout = 30
 backend_generate_request_timeout = 95
 admin_token = None
+openai_epose_our_model = False
@@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
     request_valid_json, request_json_body = validate_json(request)
-    if not request_valid_json or not request_json_body.get('messages'):
+    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
         try:
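With the stricter check above, a request to the chat/completions route now needs both messages and model or it gets the 400 response. A minimal body that passes the check; the model name is only a placeholder:

# Hypothetical request body for the chat/completions route; any non-empty
# "model" string satisfies the new validation.
import json

body = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {"role": "user", "content": "Hello!"},
    ],
}
print(json.dumps(body, indent=2))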
@@ -1,22 +1,19 @@
-from flask import jsonify, request
+import traceback
 
+import requests
+from flask import jsonify
+
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, cache, redis
+from ..cache import ONE_MONTH_SECONDS, cache
 from ..stats import server_start_time
 from ... import opts
+from ...helpers import jsonify_pretty
 from ...llm.info import get_running_model
-import openai
 
 
 @openai_bp.route('/models', methods=['GET'])
 @cache.cached(timeout=60, query_string=True)
 def openai_list_models():
-    cache_key = 'openai_model_cache::' + request.url
-    cached_response = cache.get(cache_key)
-
-    if cached_response:
-        return cached_response
-
     model, error = get_running_model()
     if not model:
         response = jsonify({
@@ -26,41 +23,49 @@ def openai_list_models():
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         oai = fetch_openai_models()
-        r = {
-            "object": "list",
-            "data": [
-                {
-                    "id": opts.running_model,
-                    "object": "model",
-                    "created": int(server_start_time.timestamp()),
-                    "owned_by": opts.llm_middleware_name,
-                    "permission": [
-                        {
-                            "id": opts.running_model,
-                            "object": "model_permission",
-                            "created": int(server_start_time.timestamp()),
-                            "allow_create_engine": False,
-                            "allow_sampling": False,
-                            "allow_logprobs": False,
-                            "allow_search_indices": False,
-                            "allow_view": True,
-                            "allow_fine_tuning": False,
-                            "organization": "*",
-                            "group": None,
-                            "is_blocking": False
-                        }
-                    ],
-                    "root": None,
-                    "parent": None
-                }
-            ]
-        }
-        response = jsonify({**r, **oai}), 200
-        cache.set(cache_key, response, timeout=60)
+        r = []
+        if opts.openai_epose_our_model:
+            r = [{
+                "object": "list",
+                "data": [
+                    {
+                        "id": opts.running_model,
+                        "object": "model",
+                        "created": int(server_start_time.timestamp()),
+                        "owned_by": opts.llm_middleware_name,
+                        "permission": [
+                            {
+                                "id": opts.running_model,
+                                "object": "model_permission",
+                                "created": int(server_start_time.timestamp()),
+                                "allow_create_engine": False,
+                                "allow_sampling": False,
+                                "allow_logprobs": False,
+                                "allow_search_indices": False,
+                                "allow_view": True,
+                                "allow_fine_tuning": False,
+                                "organization": "*",
+                                "group": None,
+                                "is_blocking": False
+                            }
+                        ],
+                        "root": None,
+                        "parent": None
+                    }
+                ]
+            }]
+        response = jsonify_pretty(r + oai), 200
         return response
 
 
 @cache.memoize(timeout=ONE_MONTH_SECONDS)
 def fetch_openai_models():
-    return openai.Model.list()
+    if opts.openai_api_key:
+        try:
+            response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
+            return response.json()['data']
+        except:
+            traceback.print_exc()
+            return []
+    else:
+        return []
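A rough sketch of what the r + oai concatenation above produces when openai_epose_our_model is enabled. The entries below are invented placeholders for opts.* values and for whatever api.openai.com returns; only the shape is the point:

# Placeholder data; real values come from opts.* and the OpenAI models API.
r = [{
    "object": "list",
    "data": [{"id": "our-local-model", "object": "model", "owned_by": "local-proxy"}],
}]
oai = [
    {"id": "gpt-3.5-turbo", "object": "model", "owned_by": "openai"},
    {"id": "gpt-4", "object": "model", "owned_by": "openai"},
]
# jsonify_pretty(r + oai) serializes this combined list as the /models response body.
print(r + oai)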
@@ -69,8 +69,10 @@ class OpenAIRequestHandler(RequestHandler):
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
 
+        model = self.request_json_body.get('model')
+
         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
 
@@ -129,7 +131,7 @@ def check_moderation_endpoint(prompt: str):
     return response['results'][0]['flagged'], offending_categories
 
 
-def build_openai_response(prompt, response):
+def build_openai_response(prompt, response, model):
     # Seperate the user's prompt from the context
     x = prompt.split('### USER:')
     if len(x) > 1:
@@ -146,7 +148,7 @@ def build_openai_response(prompt, response):
         "id": f"chatcmpl-{uuid4()}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model,
+        "model": opts.running_model if opts.openai_epose_our_model else model,
         "choices": [{
             "index": 0,
             "message": {
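The "model" field the client sees now depends on the exposure flag; a tiny standalone sketch of that selection, with placeholder names standing in for opts.openai_epose_our_model, opts.running_model, and the requested model:

# Placeholder values; in the server these come from opts.* and the request body.
openai_epose_our_model = False
running_model = 'llama-13b-backend'   # hypothetical backend model name
requested_model = 'gpt-3.5-turbo'     # hypothetical client-requested name

shown = running_model if openai_epose_our_model else requested_model
print(shown)  # -> gpt-3.5-turbo (the client sees the name it asked for)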
@@ -22,13 +22,13 @@ class RequestHandler:
     def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
 
-        # routes need to validate it, here we just load it
+        # Routes need to validate it, here we just load it
         if incoming_json:
             self.request_valid_json, self.request_json_body = validate_json(incoming_json)
         else:
             self.request_valid_json, self.request_json_body = validate_json(self.request)
         if not self.request_valid_json:
-            raise Exception(f'Not valid JSON')
+            raise Exception(f'Not valid JSON. Routes are supposed to reject invalid JSON.')
 
         self.start_time = time.time()
         self.client_ip = self.get_client_ip()
@@ -49,10 +49,12 @@ class RequestHandler:
         return self.request.headers.get('X-Api-Key')
 
     def get_client_ip(self):
-        if self.request.headers.get('cf-connecting-ip'):
-            return self.request.headers.get('cf-connecting-ip')
-        elif self.request.headers.get('x-forwarded-for'):
-            return self.request.headers.get('x-forwarded-for').split(',')[0]
+        if self.request.headers.get('X-Connecting-IP'):
+            return self.request.headers.get('X-Connecting-IP')
+        elif self.request.headers.get('Cf-Connecting-Ip'):
+            return self.request.headers.get('Cf-Connecting-Ip')
+        elif self.request.headers.get('X-Forwarded-For'):
+            return self.request.headers.get('X-Forwarded-For').split(',')[0]
         else:
             return self.request.remote_addr
 
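After this change the client-IP lookup order is X-Connecting-IP, then Cf-Connecting-Ip, then the first hop of X-Forwarded-For, then the socket address. The same precedence as a standalone sketch against a plain dict, with placeholder header values:

def pick_client_ip(headers: dict, remote_addr: str) -> str:
    # Mirrors the precedence in get_client_ip() above, but takes a plain dict
    # instead of a flask.Request so it can run on its own.
    if headers.get('X-Connecting-IP'):
        return headers['X-Connecting-IP']
    if headers.get('Cf-Connecting-Ip'):
        return headers['Cf-Connecting-Ip']
    if headers.get('X-Forwarded-For'):
        return headers['X-Forwarded-For'].split(',')[0]
    return remote_addr

print(pick_client_ip({'X-Forwarded-For': '203.0.113.7, 10.0.0.1'}, '127.0.0.1'))  # -> 203.0.113.7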
@@ -17,6 +17,9 @@ from llm_server.routes.server_error import handle_server_error
 
 # TODO: allow setting more custom ratelimits per-token
 # TODO: add more excluding to SYSTEM__ tokens
+# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
+# TODO: support turbo-instruct on openai endpoint
+# TODO: show requested model (not actual LLM backend model) in OpenAI responses
 
 try:
     import vllm
@@ -90,6 +93,12 @@ opts.enable_streaming = config['enable_streaming']
 opts.openai_api_key = config['openai_api_key']
 openai.api_key = opts.openai_api_key
 opts.admin_token = config['admin_token']
+opts.openai_epose_our_model = config['openai_epose_our_model']
+
+if opts.openai_epose_our_model and not opts.openai_api_key:
+    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
+    sys.exit(1)
+
 
 if config['http_host']:
     redis.set('http_host', config['http_host'])