improve openai endpoint, exclude system tokens more places

This commit is contained in:
Cyberes 2023-09-25 09:32:23 -06:00
parent 6459a1c91b
commit bbe5d5a8fe
9 changed files with 77 additions and 56 deletions

View File

@ -24,6 +24,7 @@ config_default_vars = {
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""", 'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
'http_host': None, 'http_host': None,
'admin_token': None, 'admin_token': None,
'openai_epose_our_model': False,
} }
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -132,7 +132,7 @@ def sum_column(table_name, column_name):
conn = db_pool.connection() conn = db_pool.connection()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}") cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%'")
result = cursor.fetchone() result = cursor.fetchone()
return result[0] if result else 0 return result[0] if result else 0
finally: finally:
@ -145,7 +145,7 @@ def get_distinct_ips_24h():
conn = db_pool.connection() conn = db_pool.connection()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,)) cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
result = cursor.fetchone() result = cursor.fetchone()
return result[0] if result else 0 return result[0] if result else 0
finally: finally:

View File

@ -2,6 +2,7 @@ import math
import re import re
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Union
import simplejson as json import simplejson as json
from flask import make_response from flask import make_response
@ -56,7 +57,7 @@ def indefinite_article(word):
return 'a' return 'a'
def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True): def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys=True):
response = make_response(json.dumps(json_dict, indent=indent, sort_keys=sort_keys)) response = make_response(json.dumps(json_dict, indent=indent, sort_keys=sort_keys))
response.headers['Content-Type'] = 'application/json; charset=utf-8' response.headers['Content-Type'] = 'application/json; charset=utf-8'
response.headers['mimetype'] = 'application/json' response.headers['mimetype'] = 'application/json'

View File

@ -29,3 +29,4 @@ openai_api_key = None
backend_request_timeout = 30 backend_request_timeout = 30
backend_generate_request_timeout = 95 backend_generate_request_timeout = 95
admin_token = None admin_token = None
openai_epose_our_model = False

View File

@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
@openai_bp.route('/chat/completions', methods=['POST']) @openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions(): def openai_chat_completions():
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages'): if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else: else:
try: try:

View File

@ -1,22 +1,19 @@
from flask import jsonify, request import traceback
import requests
from flask import jsonify
from . import openai_bp from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache, redis from ..cache import ONE_MONTH_SECONDS, cache
from ..stats import server_start_time from ..stats import server_start_time
from ... import opts from ... import opts
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model from ...llm.info import get_running_model
import openai
@openai_bp.route('/models', methods=['GET']) @openai_bp.route('/models', methods=['GET'])
@cache.cached(timeout=60, query_string=True) @cache.cached(timeout=60, query_string=True)
def openai_list_models(): def openai_list_models():
cache_key = 'openai_model_cache::' + request.url
cached_response = cache.get(cache_key)
if cached_response:
return cached_response
model, error = get_running_model() model, error = get_running_model()
if not model: if not model:
response = jsonify({ response = jsonify({
@ -26,41 +23,49 @@ def openai_list_models():
}), 500 # return 500 so Cloudflare doesn't intercept us }), 500 # return 500 so Cloudflare doesn't intercept us
else: else:
oai = fetch_openai_models() oai = fetch_openai_models()
r = { r = []
"object": "list", if opts.openai_epose_our_model:
"data": [ r = [{
{ "object": "list",
"id": opts.running_model, "data": [
"object": "model", {
"created": int(server_start_time.timestamp()), "id": opts.running_model,
"owned_by": opts.llm_middleware_name, "object": "model",
"permission": [ "created": int(server_start_time.timestamp()),
{ "owned_by": opts.llm_middleware_name,
"id": opts.running_model, "permission": [
"object": "model_permission", {
"created": int(server_start_time.timestamp()), "id": opts.running_model,
"allow_create_engine": False, "object": "model_permission",
"allow_sampling": False, "created": int(server_start_time.timestamp()),
"allow_logprobs": False, "allow_create_engine": False,
"allow_search_indices": False, "allow_sampling": False,
"allow_view": True, "allow_logprobs": False,
"allow_fine_tuning": False, "allow_search_indices": False,
"organization": "*", "allow_view": True,
"group": None, "allow_fine_tuning": False,
"is_blocking": False "organization": "*",
} "group": None,
], "is_blocking": False
"root": None, }
"parent": None ],
} "root": None,
] "parent": None
} }
response = jsonify({**r, **oai}), 200 ]
cache.set(cache_key, response, timeout=60) }]
response = jsonify_pretty(r + oai), 200
return response return response
@cache.memoize(timeout=ONE_MONTH_SECONDS) @cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models(): def fetch_openai_models():
return openai.Model.list() if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data']
except:
traceback.print_exc()
return []
else:
return []

View File

@ -69,8 +69,10 @@ class OpenAIRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': self.prompt} llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request) (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model')
if success: if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
else: else:
return backend_response, backend_response_status_code return backend_response, backend_response_status_code
@ -129,7 +131,7 @@ def check_moderation_endpoint(prompt: str):
return response['results'][0]['flagged'], offending_categories return response['results'][0]['flagged'], offending_categories
def build_openai_response(prompt, response): def build_openai_response(prompt, response, model):
# Seperate the user's prompt from the context # Seperate the user's prompt from the context
x = prompt.split('### USER:') x = prompt.split('### USER:')
if len(x) > 1: if len(x) > 1:
@ -146,7 +148,7 @@ def build_openai_response(prompt, response):
"id": f"chatcmpl-{uuid4()}", "id": f"chatcmpl-{uuid4()}",
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": opts.running_model, "model": opts.running_model if opts.openai_epose_our_model else model,
"choices": [{ "choices": [{
"index": 0, "index": 0,
"message": { "message": {

View File

@ -22,13 +22,13 @@ class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None): def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
self.request = incoming_request self.request = incoming_request
# routes need to validate it, here we just load it # Routes need to validate it, here we just load it
if incoming_json: if incoming_json:
self.request_valid_json, self.request_json_body = validate_json(incoming_json) self.request_valid_json, self.request_json_body = validate_json(incoming_json)
else: else:
self.request_valid_json, self.request_json_body = validate_json(self.request) self.request_valid_json, self.request_json_body = validate_json(self.request)
if not self.request_valid_json: if not self.request_valid_json:
raise Exception(f'Not valid JSON') raise Exception(f'Not valid JSON. Routes are supposed to reject invalid JSON.')
self.start_time = time.time() self.start_time = time.time()
self.client_ip = self.get_client_ip() self.client_ip = self.get_client_ip()
@ -49,10 +49,12 @@ class RequestHandler:
return self.request.headers.get('X-Api-Key') return self.request.headers.get('X-Api-Key')
def get_client_ip(self): def get_client_ip(self):
if self.request.headers.get('cf-connecting-ip'): if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('cf-connecting-ip') return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('x-forwarded-for'): elif self.request.headers.get('Cf-Connecting-Ip'):
return self.request.headers.get('x-forwarded-for').split(',')[0] return self.request.headers.get('Cf-Connecting-Ip')
elif self.request.headers.get('X-Forwarded-For'):
return self.request.headers.get('X-Forwarded-For').split(',')[0]
else: else:
return self.request.remote_addr return self.request.remote_addr

View File

@ -17,6 +17,9 @@ from llm_server.routes.server_error import handle_server_error
# TODO: allow setting more custom ratelimits per-token # TODO: allow setting more custom ratelimits per-token
# TODO: add more excluding to SYSTEM__ tokens # TODO: add more excluding to SYSTEM__ tokens
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: show requested model (not actual LLM backend model) in OpenAI responses
try: try:
import vllm import vllm
@ -90,6 +93,12 @@ opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key'] opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token'] opts.admin_token = config['admin_token']
opts.openai_epose_our_model = config['openai_epose_our_model']
if opts.openai_epose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
sys.exit(1)
if config['http_host']: if config['http_host']:
redis.set('http_host', config['http_host']) redis.set('http_host', config['http_host'])