Improve the OpenAI endpoint; exclude SYSTEM__ tokens in more places

This commit is contained in:
Cyberes 2023-09-25 09:32:23 -06:00
parent 6459a1c91b
commit bbe5d5a8fe
9 changed files with 77 additions and 56 deletions

View File

@ -24,6 +24,7 @@ config_default_vars = {
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
'http_host': None,
'admin_token': None,
'openai_epose_our_model': False,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -132,7 +132,7 @@ def sum_column(table_name, column_name):
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%'")
result = cursor.fetchone()
return result[0] if result else 0
finally:
@ -145,7 +145,7 @@ def get_distinct_ips_24h():
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,))
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND token NOT LIKE 'SYSTEM__%%'", (past_24_hours,))
result = cursor.fetchone()
return result[0] if result else 0
finally:

View File

@ -2,6 +2,7 @@ import math
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union
import simplejson as json
from flask import make_response
@ -56,7 +57,7 @@ def indefinite_article(word):
return 'a'
def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys=True):
response = make_response(json.dumps(json_dict, indent=indent, sort_keys=sort_keys))
response.headers['Content-Type'] = 'application/json; charset=utf-8'
response.headers['mimetype'] = 'application/json'

View File

@ -28,4 +28,5 @@ enable_streaming = True
openai_api_key = None
backend_request_timeout = 30
backend_generate_request_timeout = 95
admin_token = None
admin_token = None
openai_epose_our_model = False

View File

@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages'):
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
try:

View File

@ -1,22 +1,19 @@
from flask import jsonify, request
import traceback
import requests
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..cache import ONE_MONTH_SECONDS, cache
from ..stats import server_start_time
from ... import opts
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
import openai
@openai_bp.route('/models', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def openai_list_models():
cache_key = 'openai_model_cache::' + request.url
cached_response = cache.get(cache_key)
if cached_response:
return cached_response
model, error = get_running_model()
if not model:
response = jsonify({
@ -26,41 +23,49 @@ def openai_list_models():
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
oai = fetch_openai_models()
r = {
"object": "list",
"data": [
{
"id": opts.running_model,
"object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": opts.running_model,
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": None,
"parent": None
}
]
}
response = jsonify({**r, **oai}), 200
cache.set(cache_key, response, timeout=60)
r = []
if opts.openai_epose_our_model:
r = [{
"object": "list",
"data": [
{
"id": opts.running_model,
"object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": opts.running_model,
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": None,
"parent": None
}
]
}]
response = jsonify_pretty(r + oai), 200
return response
@cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models():
    """Fetch the model list from the OpenAI API.

    Returns:
        The ``data`` list from ``GET https://api.openai.com/v1/models``, or an
        empty list when ``opts.openai_api_key`` is not set or the request
        (or decoding its body) fails.
    """
    # Without a key there is nothing to query.
    if not opts.openai_api_key:
        return []
    try:
        response = requests.get(
            'https://api.openai.com/v1/models',
            headers={'Authorization': f"Bearer {opts.openai_api_key}"},
            timeout=10,
        )
        return response.json()['data']
    except Exception:
        # Best-effort: log the failure and degrade to an empty list. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        # NOTE(review): this function is memoized, so the empty fallback is
        # presumably cached for the full timeout as well -- confirm that is
        # acceptable.
        traceback.print_exc()
        return []

View File

@ -69,8 +69,10 @@ class OpenAIRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
else:
return backend_response, backend_response_status_code
@ -129,7 +131,7 @@ def check_moderation_endpoint(prompt: str):
return response['results'][0]['flagged'], offending_categories
def build_openai_response(prompt, response):
def build_openai_response(prompt, response, model):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
@ -146,7 +148,7 @@ def build_openai_response(prompt, response):
"id": f"chatcmpl-{uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": opts.running_model,
"model": opts.running_model if opts.openai_epose_our_model else model,
"choices": [{
"index": 0,
"message": {

View File

@ -22,13 +22,13 @@ class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
self.request = incoming_request
# routes need to validate it, here we just load it
# Routes need to validate it, here we just load it
if incoming_json:
self.request_valid_json, self.request_json_body = validate_json(incoming_json)
else:
self.request_valid_json, self.request_json_body = validate_json(self.request)
if not self.request_valid_json:
raise Exception(f'Not valid JSON')
raise Exception(f'Not valid JSON. Routes are supposed to reject invalid JSON.')
self.start_time = time.time()
self.client_ip = self.get_client_ip()
@ -49,10 +49,12 @@ class RequestHandler:
return self.request.headers.get('X-Api-Key')
def get_client_ip(self):
    """Return the client IP address for the current request.

    Headers are consulted in priority order: ``X-Connecting-IP``, then
    ``Cf-Connecting-Ip``, then ``X-Forwarded-For``. The first non-empty
    header wins. ``X-Forwarded-For`` may hold a comma-separated chain, in
    which case only its first entry is returned. When none of the headers
    is set, ``request.remote_addr`` is returned.
    """
    # NOTE(review): these headers are taken at face value; presumably a
    # trusted reverse proxy sets or strips them -- confirm.
    headers = self.request.headers
    for header_name in ('X-Connecting-IP', 'Cf-Connecting-Ip'):
        value = headers.get(header_name)
        if value:
            return value
    forwarded_for = headers.get('X-Forwarded-For')
    if forwarded_for:
        # Only the first hop of the chain is used.
        return forwarded_for.split(',')[0].strip()
    return self.request.remote_addr

View File

@ -17,6 +17,9 @@ from llm_server.routes.server_error import handle_server_error
# TODO: allow setting more custom ratelimits per-token
# TODO: exclude SYSTEM__ tokens in more places
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: show requested model (not actual LLM backend model) in OpenAI responses
try:
import vllm
@ -90,6 +93,12 @@ opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_epose_our_model = config['openai_epose_our_model']
# When our own model is NOT exposed, the OpenAI-compatible endpoints can only
# advertise models fetched from OpenAI, which needs an API key. Refuse to start
# in that configuration instead of serving an empty model list.
if not opts.openai_epose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)
if config['http_host']:
redis.set('http_host', config['http_host'])