more work on openai endpoint
commit d9bbcc42e6 (parent 9e6624e779)

@@ -26,7 +26,12 @@ config_default_vars = {
     'admin_token': None,
     'openai_epose_our_model': False,
     'openai_force_no_hashes': True,
-    'include_system_tokens_in_stats': True
+    'include_system_tokens_in_stats': True,
+    'openai_moderation_scan_last_n': 5,
+    'openai_moderation_workers': 10,
+    'openai_org_name': 'OpenAI',
+    'openai_silent_trim': False,
+    'openai_moderation_enabled': True
 }
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
 
@@ -59,14 +59,29 @@ def is_valid_api_key(api_key):
         if row is not None:
             token, uses, max_uses, expire, disabled = row
             disabled = bool(disabled)
-            if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
+            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                 return True
-        conn.commit()
         return False
     finally:
         cursor.close()
 
 
+def is_api_key_moderated(api_key):
+    if not api_key:
+        return opts.openai_moderation_enabled
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    try:
+        cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
+        row = cursor.fetchone()
+        print(bool(row[0]))
+        if row is not None:
+            return bool(row[0])
+        return opts.openai_moderation_enabled
+    finally:
+        cursor.close()
+
+
 def get_number_of_rows(table_name):
     conn = db_pool.connection()
     cursor = conn.cursor()
@@ -2,7 +2,7 @@ from llm_server.llm import oobabooga, vllm
 from llm_server.routes.cache import redis
 
 
-def get_token_count(prompt):
+def get_token_count(prompt: str):
     backend_mode = redis.get('backend_mode', str)
     if backend_mode == 'vllm':
         return vllm.tokenize(prompt)
@@ -0,0 +1,21 @@
+import requests
+
+from llm_server import opts
+
+
+def check_moderation_endpoint(prompt: str):
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f"Bearer {opts.openai_api_key}",
+    }
+    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
+    if response.status_code != 200:
+        print(response.text)
+        response.raise_for_status()
+    response = response.json()
+
+    offending_categories = []
+    for k, v in response['results'][0]['categories'].items():
+        if v:
+            offending_categories.append(k)
+    return response['results'][0]['flagged'], offending_categories
@@ -0,0 +1,132 @@
+import concurrent.futures
+import re
+import secrets
+import string
+import time
+import traceback
+from typing import Dict, List
+
+import tiktoken
+from flask import jsonify
+
+import llm_server
+from llm_server import opts
+from llm_server.llm import get_token_count
+from llm_server.routes.cache import redis
+
+ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
+ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
+
+
+def build_openai_response(prompt, response, model=None):
+    # Seperate the user's prompt from the context
+    x = prompt.split('### USER:')
+    if len(x) > 1:
+        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
+
+    # Make sure the bot doesn't put any other instructions in its response
+    # y = response.split('\n### ')
+    # if len(y) > 1:
+    #     response = re.sub(r'\n$', '', y[0].strip(' '))
+    response = re.sub(ANTI_RESPONSE_RE, '', response)
+    response = re.sub(ANTI_CONTINUATION_RE, '', response)
+
+    # TODO: async/await
+    prompt_tokens = llm_server.llm.get_token_count(prompt)
+    response_tokens = llm_server.llm.get_token_count(response)
+    running_model = redis.get('running_model', str, 'ERROR')
+
+    return jsonify({
+        "id": f"chatcmpl-{generate_oai_string(30)}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": running_model if opts.openai_expose_our_model else model,
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": response,
+            },
+            "logprobs": None,
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": response_tokens,
+            "total_tokens": prompt_tokens + response_tokens
+        }
+    })
+
+
+def generate_oai_string(length=24):
+    alphabet = string.ascii_letters + string.digits
+    return ''.join(secrets.choice(alphabet) for i in range(length))
+
+
+def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+
+    def get_token_count_tiktoken_thread(msg):
+        return len(tokenizer.encode(msg["content"]))
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
+
+    total_tokens = sum(token_counts)
+    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
+
+    # If total tokens exceed the limit, start trimming
+    if total_tokens > context_token_limit:
+        while True:
+            while total_tokens + formatting_tokens > context_token_limit:
+                # Calculate the index to start removing messages from
+                remove_index = len(prompt) // 3
+
+                while remove_index < len(prompt):
+                    total_tokens -= token_counts[remove_index]
+                    prompt.pop(remove_index)
+                    token_counts.pop(remove_index)
+                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
+                        break
+
+            def get_token_count_thread(msg):
+                return get_token_count(msg["content"])
+
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                token_counts = list(executor.map(get_token_count_thread, prompt))
+
+            total_tokens = sum(token_counts)
+            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
+
+            if total_tokens + formatting_tokens > context_token_limit:
+                # Start over, but this time calculate the token count using the backend
+                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                    token_counts = list(executor.map(get_token_count_thread, prompt))
+            else:
+                break
+
+    return prompt
+
+
+def transform_messages_to_prompt(oai_messages):
+    try:
+        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
+        for msg in oai_messages:
+            if not msg.get('content') or not msg.get('role'):
+                return False
+            if msg['role'] == 'system':
+                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
+            elif msg['role'] == 'user':
+                prompt += f'### USER: {msg["content"]}\n\n'
+            elif msg['role'] == 'assistant':
+                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
+            else:
+                return False
+    except Exception as e:
+        # TODO: use logging
+        traceback.print_exc()
+        return ''
+
+    prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
+    prompt += '\n\n### RESPONSE: '
+    return prompt
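Two of the helpers added above are meant to be used together: trim_prompt_to_fit() drops messages from the middle of the conversation until the flattened prompt fits the backend's context window, and transform_messages_to_prompt() renders the OpenAI-style message list into the "### INSTRUCTION/USER/ASSISTANT" format the backends expect. The snippet below is an illustrative sketch only, not part of this commit; the 4096-token limit is an assumed example value (the request handler passes opts.context_size instead).

```python
# Illustrative sketch only -- mirrors how OpenAIRequestHandler.handle_request() calls these helpers.
from llm_server.llm.openai.transform import transform_messages_to_prompt, trim_prompt_to_fit

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Summarize the plot of Hamlet.'},
]

trimmed = trim_prompt_to_fit(messages, 4096)    # assumed context limit; the server uses opts.context_size
prompt = transform_messages_to_prompt(trimmed)  # flattens to "### INSTRUCTION/USER/... ### RESPONSE: "
print(prompt)
```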
@@ -5,10 +5,8 @@ import tiktoken
 
 from llm_server import opts
 
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-
 def tokenize(prompt: str) -> int:
+    tokenizer = tiktoken.get_encoding("cl100k_base")
     if not prompt:
         # The tokenizers have issues when the prompt is None.
         return 0
@@ -18,4 +16,5 @@ def tokenize(prompt: str) -> int:
         return j['length']
     except:
         traceback.print_exc()
+        print(prompt)
         return len(tokenizer.encode(prompt)) + 10
@@ -32,3 +32,8 @@ admin_token = None
 openai_expose_our_model = False
 openai_force_no_hashes = True
 include_system_tokens_in_stats = True
+openai_moderation_scan_last_n = 5
+openai_moderation_workers = 10
+openai_org_name = 'OpenAI'
+openai_silent_trim = False
+openai_moderation_enabled = True
@@ -6,22 +6,26 @@ from flask import Response, request
 from llm_server import opts
 
 
+def parse_token(input_token):
+    password = None
+    if input_token.startswith('Basic '):
+        try:
+            _, password = basicauth.decode(input_token)
+        except:
+            return False
+    elif input_token.startswith('Bearer '):
+        password = input_token.split('Bearer ', maxsplit=1)
+        del password[0]
+        password = ''.join(password)
+    return password
+
+
 def check_auth(token):
     if not opts.admin_token:
         # The admin token is not set/enabled.
         # Default: deny all.
         return False
-    password = None
-    if token.startswith('Basic '):
-        try:
-            _, password = basicauth.decode(token)
-        except:
-            return False
-    elif token.startswith('Bearer '):
-        password = token.split('Bearer ', maxsplit=1)
-        del password[0]
-        password = ''.join(password)
-    return password == opts.admin_token
+    return parse_token(token) == opts.admin_token
 
 
 def authenticate():
@@ -8,7 +8,7 @@ from flask_caching import Cache
 from redis import Redis
 from redis.typing import FieldT, ExpiryT
 
-cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
+flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
 ONE_MONTH_SECONDS = 2678000
 
@@ -10,6 +10,7 @@ from flask import jsonify, request
 
 from llm_server import opts
 from llm_server.database.database import is_valid_api_key
+from llm_server.routes.auth import parse_token
 
 
 def cache_control(seconds):
@@ -29,25 +30,34 @@ def cache_control(seconds):
 
 
 def require_api_key():
-    if not opts.auth_required:
-        return
-    elif 'X-Api-Key' in request.headers:
-        if is_valid_api_key(request.headers['X-Api-Key']):
+    if 'X-Api-Key' in request.headers:
+        api_key = request.headers['X-Api-Key']
+        if api_key.startswith('SYSTEM__') or opts.auth_required:
+            if is_valid_api_key(api_key):
                 return
             else:
                 return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+    elif 'Authorization' in request.headers:
+        token = parse_token(request.headers['Authorization'])
+        if token.startswith('SYSTEM__') or opts.auth_required:
+            if is_valid_api_key(token):
+                return
+            else:
+                return jsonify({'code': 403, 'message': 'Invalid token'}), 403
     else:
         try:
            # Handle websockets
            if request.json.get('X-API-KEY'):
-                if is_valid_api_key(request.json.get('X-API-KEY')):
+                api_key = request.json.get('X-API-KEY')
+                if api_key.startswith('SYSTEM__') or opts.auth_required:
+                    if is_valid_api_key(api_key):
                         return
                     else:
                         return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
        except:
            # TODO: remove this one we're sure this works as expected
            traceback.print_exc()
-           return jsonify({'code': 401, 'message': 'API key required'}), 401
+           return
 
 
 def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
@@ -1,26 +1,17 @@
-from flask import Blueprint, request
+from flask import Blueprint
 
-from ..cache import redis
-from ..helpers.client import format_sillytavern_err
-from ..helpers.http import require_api_key
-from ..openai_request_handler import build_openai_response
+from ..request_handler import before_request
 from ..server_error import handle_server_error
 from ... import opts
-from ...helpers import auto_set_base_client_api
 
 openai_bp = Blueprint('openai/v1/', __name__)
 
 
 @openai_bp.before_request
 def before_oai_request():
-    # TODO: unify with normal before_request()
-    auto_set_base_client_api(request)
     if not opts.enable_openi_compatible_backend:
-        return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
-    if request.endpoint != 'v1.get_stats':
-        response = require_api_key()
-        if response is not None:
-            return response
+        return 'The OpenAI-compatible backend is disabled.', 401
+    return before_request()
 
 
 @openai_bp.errorhandler(500)
@@ -32,3 +23,4 @@ from .models import openai_list_models
 from .chat_completions import openai_chat_completions
 from .info import get_openai_info
 from .simulated import *
+from .completions import openai_completions
@@ -9,7 +9,8 @@ from . import openai_bp
 from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
-from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
+from ..openai_request_handler import OpenAIRequestHandler
+from ...llm.openai.transform import build_openai_response, generate_oai_string
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
@@ -0,0 +1,60 @@
+import time
+import time
+import traceback
+
+from flask import jsonify, request
+
+from . import openai_bp
+from ..cache import redis
+from ..helpers.client import format_sillytavern_err
+from ..helpers.http import validate_json
+from ..ooba_request_handler import OobaRequestHandler
+from ... import opts
+from ...llm import get_token_count
+from ...llm.openai.transform import generate_oai_string
+
+
+# TODO: add rate-limit headers?
+
+@openai_bp.route('/completions', methods=['POST'])
+def openai_completions():
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+
+    request_valid_json, request_json_body = validate_json(request)
+    if not request_valid_json or not request_json_body.get('prompt'):
+        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
+    else:
+        try:
+            response, status_code = OobaRequestHandler(request).handle_request()
+            if status_code != 200:
+                return status_code
+            output = response.json['results'][0]['text']
+
+            # TODO: async/await
+            prompt_tokens = get_token_count(request_json_body['prompt'])
+            response_tokens = get_token_count(output)
+            running_model = redis.get('running_model', str, 'ERROR')
+
+            return jsonify({
+                "id": f"cmpl-{generate_oai_string(30)}",
+                "object": "text_completion",
+                "created": int(time.time()),
+                "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
+                "choices": [
+                    {
+                        "text": output,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": None
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": response_tokens,
+                    "total_tokens": prompt_tokens + response_tokens
+                }
+            })
+        except Exception as e:
+            print(f'EXCEPTION on {request.url}!!!')
+            print(traceback.format_exc())
+            return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
@@ -1,12 +1,12 @@
 from flask import Response
 
 from . import openai_bp
-from ..cache import cache
+from ..cache import flask_cache
 from ... import opts
 
 
 @openai_bp.route('/prompt', methods=['GET'])
-@cache.cached(timeout=2678000, query_string=True)
+@flask_cache.cached(timeout=2678000, query_string=True)
 def get_openai_info():
     if opts.expose_openai_system_prompt:
         resp = Response(opts.openai_system_prompt)
@@ -4,7 +4,7 @@ import requests
 from flask import jsonify
 
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, cache, redis
+from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...helpers import jsonify_pretty
@@ -12,7 +12,7 @@ from ...llm.info import get_running_model
 
 
 @openai_bp.route('/models', methods=['GET'])
-@cache.cached(timeout=60, query_string=True)
+@flask_cache.cached(timeout=60, query_string=True)
 def openai_list_models():
     model, error = get_running_model()
     if not model:
@@ -59,7 +59,7 @@ def openai_list_models():
     return response
 
 
-@cache.memoize(timeout=ONE_MONTH_SECONDS)
+@flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
 def fetch_openai_models():
     if opts.openai_api_key:
         try:
@@ -1,13 +1,13 @@
 from flask import jsonify
 
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, cache
-from ..openai_request_handler import generate_oai_string
+from ..cache import ONE_MONTH_SECONDS, flask_cache
+from ...llm.openai.transform import generate_oai_string
 from ..stats import server_start_time
 
 
 @openai_bp.route('/organizations', methods=['GET'])
-@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
+@flask_cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
 def openai_organizations():
     return jsonify({
         "object": "list",
@@ -1,27 +1,17 @@
 import json
-import re
-import secrets
-import string
-import time
 import traceback
 from typing import Tuple
+from uuid import uuid4
 
 import flask
-import requests
-import tiktoken
 from flask import jsonify
 
-import llm_server
 from llm_server import opts
-from llm_server.database.database import log_prompt
-from llm_server.routes.cache import redis
+from llm_server.database.database import is_api_key_moderated, log_prompt
+from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
+from llm_server.threads import add_moderation_task, get_results
 
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
-ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
-
 
 class OpenAIRequestHandler(RequestHandler):
@@ -32,39 +22,32 @@ class OpenAIRequestHandler(RequestHandler):
     def handle_request(self) -> Tuple[flask.Response, int]:
         assert not self.used
 
+        if opts.openai_silent_trim:
+            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
+        else:
+            oai_messages = self.request.json['messages']
+
+        self.prompt = transform_messages_to_prompt(oai_messages)
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
             return invalid_response
 
-        self.prompt = self.transform_messages_to_prompt()
-
-        if opts.openai_api_key:
+        if opts.openai_api_key and is_api_key_moderated(self.token):
             try:
                 # Gather the last message from the user and all preceeding system messages
                 msg_l = self.request.json['messages'].copy()
                 msg_l.reverse()
-                msgs_to_check = []
-                for msg in msg_l:
-                    if msg['role'] == 'system':
-                        msgs_to_check.append(msg['content'])
-                    elif msg['role'] == 'user':
-                        msgs_to_check.append(msg['content'])
-                        break
+                tag = uuid4()
+                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
+                for i in range(num_to_check):
+                    add_moderation_task(msg_l[i]['content'], tag)
 
-                flagged = False
-                flagged_categories = []
-                # TODO: make this threaded
-                for msg in msgs_to_check:
-                    flagged, categories = check_moderation_endpoint(msg)
-                    flagged_categories.extend(categories)
-                    if flagged:
-                        break
+                flagged_categories = get_results(tag, num_to_check)
 
-                if flagged and len(flagged_categories):
-                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
+                if len(flagged_categories):
+                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                     self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
-                    self.prompt = self.transform_messages_to_prompt()
-                    # print(json.dumps(self.request.json['messages'], indent=4))
+                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
             except Exception as e:
                 print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                 print(traceback.format_exc())
@@ -92,32 +75,8 @@ class OpenAIRequestHandler(RequestHandler):
             log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
         return build_openai_response(self.prompt, backend_response), 429
 
-    def transform_messages_to_prompt(self):
-        # TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
-        try:
-            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
-            for msg in self.request.json['messages']:
-                if not msg.get('content') or not msg.get('role'):
-                    return False
-                if msg['role'] == 'system':
-                    prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
-                elif msg['role'] == 'user':
-                    prompt += f'### USER: {msg["content"]}\n\n'
-                elif msg['role'] == 'assistant':
-                    prompt += f'### ASSISTANT: {msg["content"]}\n\n'
-                else:
-                    return False
-        except Exception as e:
-            # TODO: use logging
-            print(f'Failed to transform OpenAI to prompt:', f'{e.__class__.__name__}: {e}')
-            print(traceback.format_exc())
-            return ''
-
-        prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
-        prompt += '\n\n### RESPONSE: '
-        return prompt
-
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        print(msg)
         # return build_openai_response('', msg), 400
         return jsonify({
             "error": {
@@ -127,60 +86,3 @@ class OpenAIRequestHandler(RequestHandler):
                 "code": None
             }
         }), 400
-
-
-def check_moderation_endpoint(prompt: str):
-    headers = {
-        'Content-Type': 'application/json',
-        'Authorization': f"Bearer {opts.openai_api_key}",
-    }
-    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
-    offending_categories = []
-    for k, v in response['results'][0]['categories'].items():
-        if v:
-            offending_categories.append(k)
-    return response['results'][0]['flagged'], offending_categories
-
-
-def build_openai_response(prompt, response, model=None):
-    # Seperate the user's prompt from the context
-    x = prompt.split('### USER:')
-    if len(x) > 1:
-        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
-
-    # Make sure the bot doesn't put any other instructions in its response
-    # y = response.split('\n### ')
-    # if len(y) > 1:
-    #     response = re.sub(r'\n$', '', y[0].strip(' '))
-    response = re.sub(ANTI_RESPONSE_RE, '', response)
-    response = re.sub(ANTI_CONTINUATION_RE, '', response)
-
-    # TODO: async/await
-    prompt_tokens = llm_server.llm.get_token_count(prompt)
-    response_tokens = llm_server.llm.get_token_count(response)
-    running_model = redis.get('running_model', str, 'ERROR')
-
-    return jsonify({
-        "id": f"chatcmpl-{generate_oai_string(30)}",
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": running_model if opts.openai_expose_our_model else model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response,
-            },
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": response_tokens,
-            "total_tokens": prompt_tokens + response_tokens
-        }
-    })
-
-
-def generate_oai_string(length=24):
-    alphabet = string.ascii_letters + string.digits
-    return ''.join(secrets.choice(alphabet) for i in range(length))
@@ -10,6 +10,7 @@ from llm_server.database.database import log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.routes.auth import parse_token
 from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.helpers.http import require_api_key, validate_json
@@ -21,6 +22,7 @@ DEFAULT_PRIORITY = 9999
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
+        self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
 
         # Routes need to validate it, here we just load it
         if incoming_json:
@@ -42,11 +44,12 @@ class RequestHandler:
         redis.set_dict('recent_prompters', recent_prompters)
 
     def get_auth_token(self):
-        websocket_key = self.request_json_body.get('X-API-KEY')
-        if websocket_key:
-            return websocket_key
-        else:
-            return self.request.headers.get('X-Api-Key')
+        if self.request_json_body.get('X-API-KEY'):
+            return self.request_json_body.get['X-API-KEY']
+        elif self.request.headers.get('X-Api-Key'):
+            return self.request.headers['X-Api-Key']
+        elif self.request.headers['Authorization']:
+            return parse_token(self.request.headers['Authorization'])
 
     def get_client_ip(self):
         if self.request.headers.get('X-Connecting-IP'):
@@ -96,7 +96,6 @@ class SemaphoreCheckerThread(Thread):
         new_recent_prompters = {}
 
         for ip, (timestamp, token) in recent_prompters.items():
-            # TODO: validate token
             if token and token.startswith('SYSTEM__'):
                 continue
             if current_time - timestamp <= 300:
@@ -17,6 +17,6 @@ def generate():
     try:
         return OobaRequestHandler(request).handle_request()
     except Exception as e:
-        print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+        print(f'EXCEPTION on {request.url}!!!')
         print(traceback.format_exc())
         return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
@@ -57,9 +57,14 @@ def stream(ws):
                 err_msg = invalid_response[0].json['results'][0]['text']
                 ws.send(json.dumps({
                     'event': 'text_stream',
-                    'message_num': message_num,
+                    'message_num': 0,
                     'text': err_msg
                 }))
+                ws.send(json.dumps({
+                    'event': 'stream_end',
+                    'message_num': 1
+                }))
+                ws.close()  # this is important if we encountered and error and exited early.
 
                 def background_task():
                     log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
@@ -98,12 +103,18 @@ def stream(ws):
                     except IndexError:
                         # ????
                         continue
+                    try:
                         ws.send(json.dumps({
                             'event': 'text_stream',
                             'message_num': message_num,
                             'text': new
                         }))
+                    except:
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
+                        return
 
                     message_num += 1
                     partial_response = b''  # Reset the partial response
@@ -117,8 +128,7 @@ def stream(ws):
             elapsed_time = end_time - start_time
 
             def background_task_success():
-                generated_tokens = tokenize(generated_text)
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
 
             # TODO: use async/await instead of threads
             thread = threading.Thread(target=background_task_success)
@@ -141,9 +151,13 @@ def stream(ws):
         thread = threading.Thread(target=background_task_exception)
         thread.start()
         thread.join()
+    try:
         ws.send(json.dumps({
             'event': 'stream_end',
             'message_num': message_num
         }))
         ws.close()  # this is important if we encountered and error and exited early.
+    except:
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
@@ -4,7 +4,7 @@ from flask import jsonify, request
 
 from . import bp
 from ..auth import requires_auth
-from ..cache import cache
+from ..cache import flask_cache
 from ... import opts
 from ...llm.info import get_running_model
 
@@ -21,7 +21,7 @@ def get_model():
     # We will manage caching ourself since we don't want to cache
     # when the backend is down. Also, Cloudflare won't cache 500 errors.
     cache_key = 'model_cache::' + request.url
-    cached_response = cache.get(cache_key)
+    cached_response = flask_cache.get(cache_key)
 
     if cached_response:
         return cached_response
@@ -38,7 +38,7 @@ def get_model():
         'result': opts.manual_model_name if opts.manual_model_name else model_name,
         'timestamp': int(time.time())
     }), 200
-    cache.set(cache_key, response, timeout=60)
+    flask_cache.set(cache_key, response, timeout=60)
 
     return response
 
@@ -2,11 +2,11 @@ from flask import jsonify
 
 from . import bp
 from .generate_stats import generate_stats
-from ..cache import cache
+from ..cache import flask_cache
 from ...helpers import jsonify_pretty
 
 
 @bp.route('/stats', methods=['GET'])
-@cache.cached(timeout=5, query_string=True)
+@flask_cache.cached(timeout=5, query_string=True)
 def get_stats():
     return jsonify_pretty(generate_stats())
@@ -1,9 +1,15 @@
+import json
+import threading
 import time
+import traceback
 from threading import Thread
 
+import redis as redis_redis
+
 from llm_server import opts
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
+from llm_server.llm.openai.moderation import check_moderation_endpoint
 from llm_server.routes.cache import redis
 from llm_server.routes.v1.generate_stats import generate_stats
 
@@ -70,3 +76,45 @@ def cache_stats():
     while True:
         generate_stats(regen=True)
         time.sleep(5)
+
+
+redis_moderation = redis_redis.Redis()
+
+
+def start_moderation_workers(num_workers):
+    for _ in range(num_workers):
+        t = threading.Thread(target=moderation_worker)
+        t.daemon = True
+        t.start()
+
+
+def moderation_worker():
+    while True:
+        result = redis_moderation.blpop('queue:msgs_to_check')
+        try:
+            msg, tag = json.loads(result[1])
+            _, categories = check_moderation_endpoint(msg)
+            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
+        except:
+            print(result)
+            traceback.print_exc()
+            continue
+
+
+def add_moderation_task(msg, tag):
+    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
+
+
+def get_results(tag, num_tasks):
+    tag = str(tag)  # Required for comparison with Redis results.
+    flagged_categories = set()
+    num_results = 0
+    while num_results < num_tasks:
+        result = redis_moderation.blpop('queue:flagged_categories')
+        result_tag, categories = json.loads(result[1])
+        if result_tag == tag:
+            if categories:
+                for item in categories:
+                    flagged_categories.add(item)
+        num_results += 1
+    return list(flagged_categories)
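The block above adds a small Redis-backed work queue for moderation: add_moderation_task() pushes (message, tag) pairs onto queue:msgs_to_check, the worker threads call check_moderation_endpoint() and push the flagged categories onto queue:flagged_categories, and get_results() blocks until it has collected one result per task for a given tag. A minimal end-to-end sketch follows; it is illustrative only, not part of the commit, and assumes a local Redis plus a valid opts.openai_api_key.

```python
# Illustrative sketch only -- the same round trip OpenAIRequestHandler.handle_request() performs.
from uuid import uuid4

from llm_server.threads import add_moderation_task, get_results, start_moderation_workers

start_moderation_workers(10)               # daemon threads that call the OpenAI moderation endpoint
tag = uuid4()                              # one tag groups every message belonging to a single request
messages = ['first user message', 'second user message']
for msg in messages:
    add_moderation_task(msg, tag)          # enqueued on the 'queue:msgs_to_check' Redis list
flagged = get_results(tag, len(messages))  # blocks until len(messages) results have been seen
print(flagged)                             # e.g. ['harassment'] if anything was flagged, else []
```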
@@ -10,6 +10,6 @@ To test on your local machine, run this command:
 docker run --shm-size 14g --gpus all \
     -v /storage/models/awq/MythoMax-L2-13B-AWQ:/models/MythoMax-L2-13B-AWQ \
     -p 7000:7000 -p 8888:8888 \
-    -e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 8192 --gpu-memory-utilization 1" \
+    -e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 99999 --gpu-memory-utilization 1" \
 vllm-cloud
 ```
@@ -6,7 +6,7 @@ requests~=2.31.0
 tiktoken~=0.5.0
 gunicorn
 redis~=5.0.0
-gevent
+gevent~=23.9.0.post1
 async-timeout
 flask-sock
 uvicorn~=0.23.2
@@ -18,3 +18,5 @@ simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
+urllib3~=2.0.4
+rq~=1.15.1
server.py
@@ -1,3 +1,10 @@
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
 import os
 import re
 import sys
@@ -18,14 +25,15 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
 
-# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
-# TODO: support turbo-instruct on openai endpoint
-# TODO: option to trim context in openai mode so that we silently fit the model's context
-# TODO: validate system tokens before excluding them
 # TODO: make sure prompts are logged even when the user cancels generation
 # TODO: add some sort of loadbalancer to send requests to a group of backends
 # TODO: use the current estimated wait time for ratelimit headers on openai
+# TODO: accept a header to specify if the openai endpoint should return sillytavern-formatted errors
+# TODO: allow setting concurrent gens per-backend
+# TODO: use first backend as default backend
+# TODO: allow disabling OpenAI moderation endpoint per-token
 
+# TODO: allow setting specific simoltaneous IPs allowed per token
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead
 # TODO: add more excluding to SYSTEM__ tokens
@@ -43,11 +51,12 @@ from llm_server import opts
 from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
 from llm_server.helpers import resolve_path, auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.routes.cache import cache, redis
+from llm_server.routes.cache import RedisWrapper, flask_cache
+from llm_server.llm import redis
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.threads import MainBackgroundThread, cache_stats
+from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
 
 script_path = os.path.dirname(os.path.realpath(__file__))
 
@@ -55,8 +64,8 @@ app = Flask(__name__)
 init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
-cache.init_app(app)
-cache.clear()  # clear redis cache
+flask_cache.init_app(app)
+flask_cache.clear()
 
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
@@ -81,7 +90,7 @@ if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)
 
-# TODO: this is a MESS
+# TODO: this is atrocious
 opts.mode = config['mode']
 opts.auth_required = config['auth_required']
 opts.log_prompts = config['log_prompts']
@@ -108,6 +117,11 @@ opts.admin_token = config['admin_token']
 opts.openai_expose_our_model = config['openai_epose_our_model']
 opts.openai_force_no_hashes = config['openai_force_no_hashes']
 opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
+opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
+opts.openai_moderation_workers = config['openai_moderation_workers']
+opts.openai_org_name = config['openai_org_name']
+opts.openai_silent_trim = config['openai_silent_trim']
+opts.openai_moderation_enabled = config['openai_moderation_enabled']
 
 if opts.openai_expose_our_model and not opts.openai_api_key:
     print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
@@ -134,6 +148,7 @@ else:
 
 
 def pre_fork(server):
+    llm_server.llm.redis = RedisWrapper('local_llm')
     flushed_keys = redis.flush()
     print('Flushed', len(flushed_keys), 'keys from Redis.')
 
@@ -148,6 +163,7 @@ def pre_fork(server):
 
     # Start background processes
     start_workers(opts.concurrent_gens)
+    start_moderation_workers(opts.openai_moderation_workers)
    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
    process_avg_gen_time_background_thread.daemon = True
    process_avg_gen_time_background_thread.start()
@@ -170,7 +186,7 @@ def pre_fork(server):
 @app.route('/')
 @app.route('/api')
 @app.route('/api/openai')
-@cache.cached(timeout=10)
+@flask_cache.cached(timeout=10)
 def home():
     stats = generate_stats()
 
@@ -245,4 +261,6 @@ def before_app_request():
 
 
 if __name__ == "__main__":
+    pre_fork(None)
+    print('FLASK MODE - Startup complete!')
     app.run(host='0.0.0.0', threaded=False, processes=15)