more work on openai endpoint
This commit is contained in:
parent 9e6624e779
commit d9bbcc42e6
@@ -26,7 +26,12 @@ config_default_vars = {
    'admin_token': None,
    'openai_epose_our_model': False,
    'openai_force_no_hashes': True,
    'include_system_tokens_in_stats': True
    'include_system_tokens_in_stats': True,
    'openai_moderation_scan_last_n': 5,
    'openai_moderation_workers': 10,
    'openai_org_name': 'OpenAI',
    'openai_silent_trim': False,
    'openai_moderation_enabled': True
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
@@ -59,14 +59,29 @@ def is_valid_api_key(api_key):
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        conn.commit()
        return False
    finally:
        cursor.close()


def is_api_key_moderated(api_key):
    if not api_key:
        return opts.openai_moderation_enabled
    conn = db_pool.connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        print(bool(row[0]))
        if row is not None:
            return bool(row[0])
        return opts.openai_moderation_enabled
    finally:
        cursor.close()


def get_number_of_rows(table_name):
    conn = db_pool.connection()
    cursor = conn.cursor()
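A small usage sketch of the new per-token moderation lookup above. It is illustrative only: the token value is made up and it assumes the database pool and the token_auth table are reachable.

from llm_server import opts
from llm_server.database.database import is_api_key_moderated

opts.openai_moderation_enabled = True  # global default normally loaded from the config file

# With no token, the global default is returned.
print(is_api_key_moderated(None))  # -> True

# With a token, the openai_moderation_enabled column of that token's token_auth row decides.
print(is_api_key_moderated('example-token'))  # hypothetical token; requires a matching row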
@@ -2,7 +2,7 @@ from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis


def get_token_count(prompt):
def get_token_count(prompt: str):
    backend_mode = redis.get('backend_mode', str)
    if backend_mode == 'vllm':
        return vllm.tokenize(prompt)
@@ -0,0 +1,21 @@
import requests

from llm_server import opts


def check_moderation_endpoint(prompt: str):
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {opts.openai_api_key}",
    }
    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
    if response.status_code != 200:
        print(response.text)
        response.raise_for_status()
    response = response.json()

    offending_categories = []
    for k, v in response['results'][0]['categories'].items():
        if v:
            offending_categories.append(k)
    return response['results'][0]['flagged'], offending_categories
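A minimal usage sketch of the moderation helper added above. The key and prompt text are placeholders; it assumes opts.openai_api_key has been set from the config.

from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint

opts.openai_api_key = 'sk-...'  # assumption: a valid OpenAI key configured elsewhere

flagged, categories = check_moderation_endpoint('example user message to scan')
if flagged:
    # categories is the list of category names OpenAI marked true, e.g. ['harassment']
    print('Message flagged for:', categories)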
@@ -0,0 +1,132 @@
import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List

import tiktoken
from flask import jsonify

import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.


def build_openai_response(prompt, response, model=None):
    # Seperate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    # y = response.split('\n### ')
    # if len(y) > 1:
    #     response = re.sub(r'\n$', '', y[0].strip(' '))
    response = re.sub(ANTI_RESPONSE_RE, '', response)
    response = re.sub(ANTI_CONTINUATION_RE, '', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    return jsonify({
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "logprobs": None,
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })


def generate_oai_string(length=24):
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for i in range(length))


def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def get_token_count_tiktoken_thread(msg):
        return len(tokenizer.encode(msg["content"]))

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))

    total_tokens = sum(token_counts)
    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens

    # If total tokens exceed the limit, start trimming
    if total_tokens > context_token_limit:
        while True:
            while total_tokens + formatting_tokens > context_token_limit:
                # Calculate the index to start removing messages from
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    total_tokens -= token_counts[remove_index]
                    prompt.pop(remove_index)
                    token_counts.pop(remove_index)
                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                        break

            def get_token_count_thread(msg):
                return get_token_count(msg["content"])

            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                token_counts = list(executor.map(get_token_count_thread, prompt))

            total_tokens = sum(token_counts)
            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens

            if total_tokens + formatting_tokens > context_token_limit:
                # Start over, but this time calculate the token count using the backend
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    token_counts = list(executor.map(get_token_count_thread, prompt))
            else:
                break

    return prompt


def transform_messages_to_prompt(oai_messages):
    try:
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            if msg['role'] == 'system':
                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
            elif msg['role'] == 'user':
                prompt += f'### USER: {msg["content"]}\n\n'
            elif msg['role'] == 'assistant':
                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
            else:
                return False
    except Exception as e:
        # TODO: use logging
        traceback.print_exc()
        return ''

    prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
    prompt += '\n\n### RESPONSE: '
    return prompt
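A quick sketch of the prompt format transform_messages_to_prompt() produces, under the assumption that opts.openai_system_prompt is configured. The message contents and the expected output shown in comments are illustrative.

from llm_server import opts
from llm_server.llm.openai.transform import transform_messages_to_prompt

opts.openai_system_prompt = 'You are a helpful assistant.'  # assumption: set from config

messages = [
    {'role': 'system', 'content': 'Stay in character.'},
    {'role': 'user', 'content': 'Hello!'},
]
print(transform_messages_to_prompt(messages))
# Roughly:
# ### INSTRUCTION: You are a helpful assistant.### INSTRUCTION: Stay in character.
#
# ### USER: Hello!
#
# ### RESPONSE: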
@@ -5,10 +5,8 @@ import tiktoken
from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def tokenize(prompt: str) -> int:
    tokenizer = tiktoken.get_encoding("cl100k_base")
    if not prompt:
        # The tokenizers have issues when the prompt is None.
        return 0
@@ -18,4 +16,5 @@ def tokenize(prompt: str) -> int:
        return j['length']
    except:
        traceback.print_exc()
        print(prompt)
        return len(tokenizer.encode(prompt)) + 10
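An illustrative sketch of the fallback path above: when the vLLM tokenize call fails, the count comes from tiktoken's cl100k_base encoding plus a margin of 10 tokens. The prompt string here is made up.

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")
prompt = 'Hello, world!'
fallback_count = len(tokenizer.encode(prompt)) + 10  # same arithmetic as the except branch
print(fallback_count)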
@@ -32,3 +32,8 @@ admin_token = None
openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
@@ -6,22 +6,26 @@ from flask import Response, request
from llm_server import opts


def parse_token(input_token):
    password = None
    if input_token.startswith('Basic '):
        try:
            _, password = basicauth.decode(input_token)
        except:
            return False
    elif input_token.startswith('Bearer '):
        password = input_token.split('Bearer ', maxsplit=1)
        del password[0]
        password = ''.join(password)
    return password


def check_auth(token):
    if not opts.admin_token:
        # The admin token is not set/enabled.
        # Default: deny all.
        return False
    password = None
    if token.startswith('Basic '):
        try:
            _, password = basicauth.decode(token)
        except:
            return False
    elif token.startswith('Bearer '):
        password = token.split('Bearer ', maxsplit=1)
        del password[0]
        password = ''.join(password)
    return password == opts.admin_token
    return parse_token(token) == opts.admin_token


def authenticate():
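A quick sketch of how the new parse_token() helper behaves for the two header styles it recognizes. The credentials are illustrative, and the Basic branch assumes basicauth.decode() returns a (username, password) pair, which is how the function above unpacks it.

import base64

from llm_server.routes.auth import parse_token

# Bearer tokens: everything after the 'Bearer ' prefix is returned.
print(parse_token('Bearer my-secret-token'))  # -> 'my-secret-token'

# Basic auth: the base64 user:password pair is decoded and the password half is returned.
encoded = 'Basic ' + base64.b64encode(b'user:my-secret').decode()  # hypothetical credentials
print(parse_token(encoded))  # -> 'my-secret'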
@@ -8,7 +8,7 @@ from flask_caching import Cache
from redis import Redis
from redis.typing import FieldT, ExpiryT

cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})

ONE_MONTH_SECONDS = 2678000
@@ -10,6 +10,7 @@ from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import is_valid_api_key
from llm_server.routes.auth import parse_token


def cache_control(seconds):
@@ -29,25 +30,34 @@ def cache_control(seconds):


def require_api_key():
    if not opts.auth_required:
        return
    elif 'X-Api-Key' in request.headers:
        if is_valid_api_key(request.headers['X-Api-Key']):
    if 'X-Api-Key' in request.headers:
        api_key = request.headers['X-Api-Key']
        if api_key.startswith('SYSTEM__') or opts.auth_required:
            if is_valid_api_key(api_key):
                return
            else:
                return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
    elif 'Authorization' in request.headers:
        token = parse_token(request.headers['Authorization'])
        if token.startswith('SYSTEM__') or opts.auth_required:
            if is_valid_api_key(token):
                return
            else:
                return jsonify({'code': 403, 'message': 'Invalid token'}), 403
    else:
        try:
            # Handle websockets
            if request.json.get('X-API-KEY'):
                if is_valid_api_key(request.json.get('X-API-KEY')):
            api_key = request.json.get('X-API-KEY')
            if api_key.startswith('SYSTEM__') or opts.auth_required:
                if is_valid_api_key(api_key):
                    return
                else:
                    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
        except:
            # TODO: remove this one we're sure this works as expected
            traceback.print_exc()
    return jsonify({'code': 401, 'message': 'API key required'}), 401
    return


def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
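For reference, a hypothetical client call against this gateway showing the two header styles the reworked require_api_key() accepts. The base URL, route path, and key are placeholders; only the '/api/v1/' prefix is taken from the blueprint registration later in this commit.

import requests

BASE = 'http://localhost:5000/api/v1'  # assumption: local deployment, default Flask port

# Either an X-Api-Key header...
requests.get(f'{BASE}/model', headers={'X-Api-Key': 'example-token'})

# ...or a standard Authorization header, which parse_token() unwraps server-side.
requests.get(f'{BASE}/model', headers={'Authorization': 'Bearer example-token'})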
@@ -1,26 +1,17 @@
from flask import Blueprint, request
from flask import Blueprint

from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response
from ..request_handler import before_request
from ..server_error import handle_server_error
from ... import opts
from ...helpers import auto_set_base_client_api

openai_bp = Blueprint('openai/v1/', __name__)


@openai_bp.before_request
def before_oai_request():
    # TODO: unify with normal before_request()
    auto_set_base_client_api(request)
    if not opts.enable_openi_compatible_backend:
        return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:
            return response
        return 'The OpenAI-compatible backend is disabled.', 401
    return before_request()


@openai_bp.errorhandler(500)
@@ -32,3 +23,4 @@ from .models import openai_list_models
from .chat_completions import openai_chat_completions
from .info import get_openai_info
from .simulated import *
from .completions import openai_completions
@@ -9,7 +9,8 @@ from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
from ..openai_request_handler import OpenAIRequestHandler
from ...llm.openai.transform import build_openai_response, generate_oai_string
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
@@ -0,0 +1,60 @@
import time
import time
import traceback

from flask import jsonify, request

from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...llm import get_token_count
from ...llm.openai.transform import generate_oai_string


# TODO: add rate-limit headers?

@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
    else:
        try:
            response, status_code = OobaRequestHandler(request).handle_request()
            if status_code != 200:
                return status_code
            output = response.json['results'][0]['text']

            # TODO: async/await
            prompt_tokens = get_token_count(request_json_body['prompt'])
            response_tokens = get_token_count(output)
            running_model = redis.get('running_model', str, 'ERROR')

            return jsonify({
                "id": f"cmpl-{generate_oai_string(30)}",
                "object": "text_completion",
                "created": int(time.time()),
                "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
                "choices": [
                    {
                        "text": output,
                        "index": 0,
                        "logprobs": None,
                        "finish_reason": None
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": response_tokens,
                    "total_tokens": prompt_tokens + response_tokens
                }
            })
        except Exception as e:
            print(f'EXCEPTION on {request.url}!!!')
            print(traceback.format_exc())
            return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
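A hypothetical client request against the new legacy-completions route added above. The '/api/openai/v1/' prefix comes from the blueprint registration in this commit; the host, port, token, and prompt are placeholders.

import requests

resp = requests.post(
    'http://localhost:5000/api/openai/v1/completions',
    headers={'Authorization': 'Bearer example-token'},
    json={'prompt': 'Once upon a time'},
)
print(resp.json()['choices'][0]['text'])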
@@ -1,12 +1,12 @@
from flask import Response

from . import openai_bp
from ..cache import cache
from ..cache import flask_cache
from ... import opts


@openai_bp.route('/prompt', methods=['GET'])
@cache.cached(timeout=2678000, query_string=True)
@flask_cache.cached(timeout=2678000, query_string=True)
def get_openai_info():
    if opts.expose_openai_system_prompt:
        resp = Response(opts.openai_system_prompt)
@@ -4,7 +4,7 @@ import requests
from flask import jsonify

from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...helpers import jsonify_pretty
@@ -12,7 +12,7 @@ from ...llm.info import get_running_model


@openai_bp.route('/models', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
    model, error = get_running_model()
    if not model:
@@ -59,7 +59,7 @@ def openai_list_models():
    return response


@cache.memoize(timeout=ONE_MONTH_SECONDS)
@flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models():
    if opts.openai_api_key:
        try:
@@ -1,13 +1,13 @@
from flask import jsonify

from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache
from ..openai_request_handler import generate_oai_string
from ..cache import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time


@openai_bp.route('/organizations', methods=['GET'])
@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
@flask_cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
def openai_organizations():
    return jsonify({
        "object": "list",
@@ -1,27 +1,17 @@
import json
import re
import secrets
import string
import time
import traceback
from typing import Tuple
from uuid import uuid4

import flask
import requests
import tiktoken
from flask import jsonify

import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.routes.cache import redis
from llm_server.database.database import is_api_key_moderated, log_prompt
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler

tokenizer = tiktoken.get_encoding("cl100k_base")

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
from llm_server.threads import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
@@ -32,39 +22,32 @@ class OpenAIRequestHandler(RequestHandler):
    def handle_request(self) -> Tuple[flask.Response, int]:
        assert not self.used

        if opts.openai_silent_trim:
            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages)
        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        self.prompt = self.transform_messages_to_prompt()

        if opts.openai_api_key:
        if opts.openai_api_key and is_api_key_moderated(self.token):
            try:
                # Gather the last message from the user and all preceeding system messages
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                msgs_to_check = []
                for msg in msg_l:
                    if msg['role'] == 'system':
                        msgs_to_check.append(msg['content'])
                    elif msg['role'] == 'user':
                        msgs_to_check.append(msg['content'])
                        break
                tag = uuid4()
                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)

                flagged = False
                flagged_categories = []
                # TODO: make this threaded
                for msg in msgs_to_check:
                    flagged, categories = check_moderation_endpoint(msg)
                    flagged_categories.extend(categories)
                    if flagged:
                        break
                flagged_categories = get_results(tag, num_to_check)

                if flagged and len(flagged_categories):
                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
                if len(flagged_categories):
                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
                    self.prompt = self.transform_messages_to_prompt()
                    # print(json.dumps(self.request.json['messages'], indent=4))
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                print(traceback.format_exc())
@@ -92,32 +75,8 @@ class OpenAIRequestHandler(RequestHandler):
        log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
        return build_openai_response(self.prompt, backend_response), 429

    def transform_messages_to_prompt(self):
        # TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
        try:
            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
            for msg in self.request.json['messages']:
                if not msg.get('content') or not msg.get('role'):
                    return False
                if msg['role'] == 'system':
                    prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
                elif msg['role'] == 'user':
                    prompt += f'### USER: {msg["content"]}\n\n'
                elif msg['role'] == 'assistant':
                    prompt += f'### ASSISTANT: {msg["content"]}\n\n'
                else:
                    return False
        except Exception as e:
            # TODO: use logging
            print(f'Failed to transform OpenAI to prompt:', f'{e.__class__.__name__}: {e}')
            print(traceback.format_exc())
            return ''

        prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
        prompt += '\n\n### RESPONSE: '
        return prompt

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        print(msg)
        # return build_openai_response('', msg), 400
        return jsonify({
            "error": {
@@ -127,60 +86,3 @@ class OpenAIRequestHandler(RequestHandler):
            "code": None
        }
    }), 400


def check_moderation_endpoint(prompt: str):
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {opts.openai_api_key}",
    }
    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
    offending_categories = []
    for k, v in response['results'][0]['categories'].items():
        if v:
            offending_categories.append(k)
    return response['results'][0]['flagged'], offending_categories


def build_openai_response(prompt, response, model=None):
    # Seperate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    # y = response.split('\n### ')
    # if len(y) > 1:
    #     response = re.sub(r'\n$', '', y[0].strip(' '))
    response = re.sub(ANTI_RESPONSE_RE, '', response)
    response = re.sub(ANTI_CONTINUATION_RE, '', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    return jsonify({
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })


def generate_oai_string(length=24):
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for i in range(length))
@@ -10,6 +10,7 @@ from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import require_api_key, validate_json
@@ -21,6 +22,7 @@ DEFAULT_PRIORITY = 9999
class RequestHandler:
    def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
        self.request = incoming_request
        self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'

        # Routes need to validate it, here we just load it
        if incoming_json:
@@ -42,11 +44,12 @@ class RequestHandler:
        redis.set_dict('recent_prompters', recent_prompters)

    def get_auth_token(self):
        websocket_key = self.request_json_body.get('X-API-KEY')
        if websocket_key:
            return websocket_key
        else:
            return self.request.headers.get('X-Api-Key')
        if self.request_json_body.get('X-API-KEY'):
            return self.request_json_body.get['X-API-KEY']
        elif self.request.headers.get('X-Api-Key'):
            return self.request.headers['X-Api-Key']
        elif self.request.headers['Authorization']:
            return parse_token(self.request.headers['Authorization'])

    def get_client_ip(self):
        if self.request.headers.get('X-Connecting-IP'):
@@ -96,7 +96,6 @@ class SemaphoreCheckerThread(Thread):
            new_recent_prompters = {}

            for ip, (timestamp, token) in recent_prompters.items():
                # TODO: validate token
                if token and token.startswith('SYSTEM__'):
                    continue
                if current_time - timestamp <= 300:
@@ -17,6 +17,6 @@ def generate():
    try:
        return OobaRequestHandler(request).handle_request()
    except Exception as e:
        print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
        print(f'EXCEPTION on {request.url}!!!')
        print(traceback.format_exc())
        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
@@ -57,9 +57,14 @@ def stream(ws):
        err_msg = invalid_response[0].json['results'][0]['text']
        ws.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'message_num': 0,
            'text': err_msg
        }))
        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': 1
        }))
        ws.close()  # this is important if we encountered and error and exited early.

        def background_task():
            log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
@@ -98,12 +103,18 @@ def stream(ws):
                    except IndexError:
                        # ????
                        continue

                    try:
                        ws.send(json.dumps({
                            'event': 'text_stream',
                            'message_num': message_num,
                            'text': new
                        }))
                    except:
                        end_time = time.time()
                        elapsed_time = end_time - start_time
                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
                        return

                    message_num += 1
                    partial_response = b''  # Reset the partial response
@@ -117,8 +128,7 @@ def stream(ws):
        elapsed_time = end_time - start_time

        def background_task_success():
            generated_tokens = tokenize(generated_text)
            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))

        # TODO: use async/await instead of threads
        thread = threading.Thread(target=background_task_success)
@@ -141,9 +151,13 @@ def stream(ws):
        thread = threading.Thread(target=background_task_exception)
        thread.start()
        thread.join()

    try:
        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))
        ws.close()  # this is important if we encountered and error and exited early.
    except:
        end_time = time.time()
        elapsed_time = end_time - start_time
        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
@@ -4,7 +4,7 @@ from flask import jsonify, request
from . import bp
from ..auth import requires_auth
from ..cache import cache
from ..cache import flask_cache
from ... import opts
from ...llm.info import get_running_model
@@ -21,7 +21,7 @@ def get_model():
    # We will manage caching ourself since we don't want to cache
    # when the backend is down. Also, Cloudflare won't cache 500 errors.
    cache_key = 'model_cache::' + request.url
    cached_response = cache.get(cache_key)
    cached_response = flask_cache.get(cache_key)

    if cached_response:
        return cached_response
@@ -38,7 +38,7 @@ def get_model():
        'result': opts.manual_model_name if opts.manual_model_name else model_name,
        'timestamp': int(time.time())
    }), 200
    cache.set(cache_key, response, timeout=60)
    flask_cache.set(cache_key, response, timeout=60)

    return response
@@ -2,11 +2,11 @@ from flask import jsonify

from . import bp
from .generate_stats import generate_stats
from ..cache import cache
from ..cache import flask_cache
from ...helpers import jsonify_pretty


@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=5, query_string=True)
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
    return jsonify_pretty(generate_stats())
@@ -1,9 +1,15 @@
import json
import threading
import time
import traceback
from threading import Thread

import redis as redis_redis

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
@@ -70,3 +76,45 @@ def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)


redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker)
        t.daemon = True
        t.start()


def moderation_worker():
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except:
            print(result)
            traceback.print_exc()
            continue


def add_moderation_task(msg, tag):
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                for item in categories:
                    flagged_categories.add(item)
            num_results += 1
    return list(flagged_categories)
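A rough end-to-end sketch of how the moderation queue above is meant to be driven, mirroring what OpenAIRequestHandler.handle_request() does earlier in this commit. The worker count and messages are illustrative; it assumes a local Redis instance and a configured OpenAI key.

from uuid import uuid4

from llm_server.threads import add_moderation_task, get_results, start_moderation_workers

start_moderation_workers(10)  # spawn daemon threads that blpop 'queue:msgs_to_check'

tag = uuid4()  # one tag groups all scans belonging to a single request
messages = ['first message to scan', 'second message to scan']
for msg in messages:
    add_moderation_task(msg, tag)

# Blocks until one result per queued task has been consumed, then returns the union of categories.
flagged_categories = get_results(tag, len(messages))
print(flagged_categories)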
@@ -10,6 +10,6 @@ To test on your local machine, run this command:
docker run --shm-size 14g --gpus all \
-v /storage/models/awq/MythoMax-L2-13B-AWQ:/models/MythoMax-L2-13B-AWQ \
-p 7000:7000 -p 8888:8888 \
-e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 8192 --gpu-memory-utilization 1" \
-e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 99999 --gpu-memory-utilization 1" \
vllm-cloud
```
@@ -6,7 +6,7 @@ requests~=2.31.0
tiktoken~=0.5.0
gunicorn
redis~=5.0.0
gevent
gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
@@ -18,3 +18,5 @@ simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
rq~=1.15.1
server.py
@@ -1,3 +1,10 @@
try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import os
import re
import sys
@@ -18,14 +25,15 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio

# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: option to trim context in openai mode so that we silently fit the model's context
# TODO: validate system tokens before excluding them
# TODO: make sure prompts are logged even when the user cancels generation
# TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: use the current estimated wait time for ratelimit headers on openai
# TODO: accept a header to specify if the openai endpoint should return sillytavern-formatted errors
# TODO: allow setting concurrent gens per-backend
# TODO: use first backend as default backend
# TODO: allow disabling OpenAI moderation endpoint per-token

# TODO: allow setting specific simoltaneous IPs allowed per token
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
# TODO: add more excluding to SYSTEM__ tokens
@@ -43,11 +51,12 @@ from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers

script_path = os.path.dirname(os.path.realpath(__file__))
@@ -55,8 +64,8 @@ app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
cache.init_app(app)
cache.clear()  # clear redis cache
flask_cache.init_app(app)
flask_cache.clear()

config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
@@ -81,7 +90,7 @@ if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# TODO: this is a MESS
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
@@ -108,6 +117,11 @@ opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']

if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
@@ -134,6 +148,7 @@ else:


def pre_fork(server):
    llm_server.llm.redis = RedisWrapper('local_llm')
    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')
@@ -148,6 +163,7 @@ def pre_fork(server):

    # Start background processes
    start_workers(opts.concurrent_gens)
    start_moderation_workers(opts.openai_moderation_workers)
    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
    process_avg_gen_time_background_thread.daemon = True
    process_avg_gen_time_background_thread.start()
@@ -170,7 +186,7 @@ def pre_fork(server):
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@cache.cached(timeout=10)
@flask_cache.cached(timeout=10)
def home():
    stats = generate_stats()
@@ -245,4 +261,6 @@ def before_app_request():


if __name__ == "__main__":
    pre_fork(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)