more work on openai endpoint

This commit is contained in:
Cyberes 2023-09-26 22:09:11 -06:00
parent 9e6624e779
commit d9bbcc42e6
28 changed files with 441 additions and 211 deletions

View File

@ -26,7 +26,12 @@ config_default_vars = {
'admin_token': None,
'openai_epose_our_model': False,
'openai_force_no_hashes': True,
'include_system_tokens_in_stats': True
'include_system_tokens_in_stats': True,
'openai_moderation_scan_last_n': 5,
'openai_moderation_workers': 10,
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -59,14 +59,29 @@ def is_valid_api_key(api_key):
if row is not None:
token, uses, max_uses, expire, disabled = row
disabled = bool(disabled)
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True
conn.commit()
return False
finally:
cursor.close()
def is_api_key_moderated(api_key):
if not api_key:
return opts.openai_moderation_enabled
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()
if row is not None:
return bool(row[0])
return opts.openai_moderation_enabled
finally:
cursor.close()
def get_number_of_rows(table_name):
conn = db_pool.connection()
cursor = conn.cursor()

View File

@ -2,7 +2,7 @@ from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis
def get_token_count(prompt):
def get_token_count(prompt: str):
backend_mode = redis.get('backend_mode', str)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)

View File

View File

@ -0,0 +1,21 @@
import requests
from llm_server import opts
def check_moderation_endpoint(prompt: str):
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {opts.openai_api_key}",
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200:
print(response.text)
response.raise_for_status()
response = response.json()
offending_categories = []
for k, v in response['results'][0]['categories'].items():
if v:
offending_categories.append(k)
return response['results'][0]['flagged'], offending_categories
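
check_moderation_endpoint() wraps OpenAI's /v1/moderations API and returns the flagged boolean plus the names of the offending categories. A minimal usage sketch, assuming opts.openai_api_key is configured (illustrative only, not part of the diff):

```python
# Minimal sketch: scan one message with the new moderation helper.
from llm_server.llm.openai.moderation import check_moderation_endpoint

flagged, categories = check_moderation_endpoint('message text to scan')
if flagged:
    print('Flagged categories:', categories)
```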

View File

@ -0,0 +1,132 @@
import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def build_openai_response(prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
return jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
})
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg["content"]))
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
# If total tokens exceed the limit, start trimming
if total_tokens > context_token_limit:
while True:
while total_tokens + formatting_tokens > context_token_limit:
# Calculate the index to start removing messages from
remove_index = len(prompt) // 3
while remove_index < len(prompt):
total_tokens -= token_counts[remove_index]
prompt.pop(remove_index)
token_counts.pop(remove_index)
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg["content"])
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
return prompt
def transform_messages_to_prompt(oai_messages):
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in oai_messages:
if not msg.get('content') or not msg.get('role'):
return False
if msg['role'] == 'system':
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
elif msg['role'] == 'user':
prompt += f'### USER: {msg["content"]}\n\n'
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
return False
except Exception as e:
# TODO: use logging
traceback.print_exc()
return ''
prompt = prompt.strip(' ').strip('\n').strip('\n\n') # TODO: this is really lazy
prompt += '\n\n### RESPONSE: '
return prompt
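
Taken together, trim_prompt_to_fit() and transform_messages_to_prompt() turn an OpenAI-style message list into the backend's instruction-format prompt. A short sketch of how the chat handler in this commit combines them (message contents are made up):

```python
# Illustrative: mirrors OpenAIRequestHandler.handle_request() in this commit.
from llm_server import opts
from llm_server.llm.openai.transform import transform_messages_to_prompt, trim_prompt_to_fit

messages = [
    {'role': 'system', 'content': 'Stay in character.'},
    {'role': 'user', 'content': 'Hello!'},
]
if opts.openai_silent_trim:
    # Drop messages from the middle of the conversation until it fits the context window.
    messages = trim_prompt_to_fit(messages, opts.context_size)

prompt = transform_messages_to_prompt(messages)
# prompt is an '### INSTRUCTION:/### USER:/### ASSISTANT:' string ending in '### RESPONSE: ',
# or False/'' if a message is missing its role or content.
```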

View File

@ -5,10 +5,8 @@ import tiktoken
from llm_server import opts
tokenizer = tiktoken.get_encoding("cl100k_base")
def tokenize(prompt: str) -> int:
tokenizer = tiktoken.get_encoding("cl100k_base")
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
@ -18,4 +16,5 @@ def tokenize(prompt: str) -> int:
return j['length']
except:
traceback.print_exc()
print(prompt)
return len(tokenizer.encode(prompt)) + 10

View File

@ -32,3 +32,8 @@ admin_token = None
openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True

View File

@ -6,22 +6,26 @@ from flask import Response, request
from llm_server import opts
def parse_token(input_token):
password = None
if input_token.startswith('Basic '):
try:
_, password = basicauth.decode(input_token)
except:
return False
elif input_token.startswith('Bearer '):
password = input_token.split('Bearer ', maxsplit=1)
del password[0]
password = ''.join(password)
return password
def check_auth(token):
if not opts.admin_token:
# The admin token is not set/enabled.
# Default: deny all.
return False
password = None
if token.startswith('Basic '):
try:
_, password = basicauth.decode(token)
except:
return False
elif token.startswith('Bearer '):
password = token.split('Bearer ', maxsplit=1)
del password[0]
password = ''.join(password)
return password == opts.admin_token
return parse_token(token) == opts.admin_token
def authenticate():
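
With the Bearer/Basic parsing factored out, parse_token() can be reused by check_auth() above and by the API-key checks in require_api_key(). A quick illustration of its behaviour (values are made up):

```python
# parse_token() returns the embedded secret, or None/False when it can't parse one.
from llm_server.routes.auth import parse_token

parse_token('Bearer abc123')           # -> 'abc123'
parse_token('Basic dXNlcjphYmMxMjM=')  # -> 'abc123' (the password half of 'user:abc123')
parse_token('abc123')                  # -> None (no recognised scheme)
```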

View File

@ -8,7 +8,7 @@ from flask_caching import Cache
from redis import Redis
from redis.typing import FieldT, ExpiryT
cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000

View File

@ -10,6 +10,7 @@ from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import is_valid_api_key
from llm_server.routes.auth import parse_token
def cache_control(seconds):
@ -29,25 +30,34 @@ def cache_control(seconds):
def require_api_key():
if not opts.auth_required:
return
elif 'X-Api-Key' in request.headers:
if is_valid_api_key(request.headers['X-Api-Key']):
if 'X-Api-Key' in request.headers:
api_key = request.headers['X-Api-Key']
if api_key.startswith('SYSTEM__') or opts.auth_required:
if is_valid_api_key(api_key):
return
else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
elif 'Authorization' in request.headers:
token = parse_token(request.headers['Authorization'])
if token and (token.startswith('SYSTEM__') or opts.auth_required):
if is_valid_api_key(token):
return
else:
return jsonify({'code': 403, 'message': 'Invalid token'}), 403
else:
try:
# Handle websockets
if request.json.get('X-API-KEY'):
if is_valid_api_key(request.json.get('X-API-KEY')):
api_key = request.json.get('X-API-KEY')
if api_key.startswith('SYSTEM__') or opts.auth_required:
if is_valid_api_key(api_key):
return
else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
except:
# TODO: remove this once we're sure this works as expected
traceback.print_exc()
return jsonify({'code': 401, 'message': 'API key required'}), 401
return
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
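
The reworked require_api_key() now accepts a key from three places: an X-Api-Key header, an Authorization header (Basic or Bearer, via parse_token()), or an X-API-KEY field in the JSON body as used by websocket clients. A hedged client-side sketch; the host, port, endpoint path and key below are placeholders:

```python
# Hypothetical client calls; each passes the API key in one of the accepted locations.
import requests

base = 'http://localhost:5000/api/v1'  # placeholder host/port/prefix
key = 'YOUR-API-KEY'                   # placeholder key
body = {'prompt': 'hi'}

requests.post(f'{base}/generate', headers={'X-Api-Key': key}, json=body)
requests.post(f'{base}/generate', headers={'Authorization': f'Bearer {key}'}, json=body)
requests.post(f'{base}/generate', json={**body, 'X-API-KEY': key})
```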

View File

@ -1,26 +1,17 @@
from flask import Blueprint, request
from flask import Blueprint
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response
from ..request_handler import before_request
from ..server_error import handle_server_error
from ... import opts
from ...helpers import auto_set_base_client_api
openai_bp = Blueprint('openai/v1/', __name__)
@openai_bp.before_request
def before_oai_request():
# TODO: unify with normal before_request()
auto_set_base_client_api(request)
if not opts.enable_openi_compatible_backend:
return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
if request.endpoint != 'v1.get_stats':
response = require_api_key()
if response is not None:
return response
return 'The OpenAI-compatible backend is disabled.', 401
return before_request()
@openai_bp.errorhandler(500)
@ -32,3 +23,4 @@ from .models import openai_list_models
from .chat_completions import openai_chat_completions
from .info import get_openai_info
from .simulated import *
from .completions import openai_completions

View File

@ -9,7 +9,8 @@ from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
from ..openai_request_handler import OpenAIRequestHandler
from ...llm.openai.transform import build_openai_response, generate_oai_string
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator

View File

@ -0,0 +1,60 @@
import time
import traceback
from flask import jsonify, request
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...llm import get_token_count
from ...llm.openai.transform import generate_oai_string
# TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
response, status_code = OobaRequestHandler(request).handle_request()
if status_code != 200:
return response, status_code
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', str, 'ERROR')
return jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
})
except Exception as e:
print(f'EXCEPTION on {request.url}!!!')
print(traceback.format_exc())
return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
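
A hedged example of calling the new completions route; the URL prefix comes from the openai_bp registration in server.py, while the host, port and key are placeholders:

```python
# Hypothetical client call against the new OpenAI-compatible completions endpoint.
import requests

resp = requests.post(
    'http://localhost:5000/api/openai/v1/completions',  # placeholder host/port
    headers={'Authorization': 'Bearer YOUR-API-KEY'},
    json={'model': 'gpt-3.5-turbo', 'prompt': 'Say hello.'},
    timeout=60,
)
print(resp.json()['choices'][0]['text'])
```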

View File

@ -1,12 +1,12 @@
from flask import Response
from . import openai_bp
from ..cache import cache
from ..cache import flask_cache
from ... import opts
@openai_bp.route('/prompt', methods=['GET'])
@cache.cached(timeout=2678000, query_string=True)
@flask_cache.cached(timeout=2678000, query_string=True)
def get_openai_info():
if opts.expose_openai_system_prompt:
resp = Response(opts.openai_system_prompt)

View File

@ -4,7 +4,7 @@ import requests
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...helpers import jsonify_pretty
@ -12,7 +12,7 @@ from ...llm.info import get_running_model
@openai_bp.route('/models', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
model, error = get_running_model()
if not model:
@ -59,7 +59,7 @@ def openai_list_models():
return response
@cache.memoize(timeout=ONE_MONTH_SECONDS)
@flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models():
if opts.openai_api_key:
try:

View File

@ -1,13 +1,13 @@
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache
from ..openai_request_handler import generate_oai_string
from ..cache import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time
@openai_bp.route('/organizations', methods=['GET'])
@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
@flask_cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
def openai_organizations():
return jsonify({
"object": "list",

View File

@ -1,27 +1,17 @@
import json
import re
import secrets
import string
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
import requests
import tiktoken
from flask import jsonify
import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.routes.cache import redis
from llm_server.database.database import is_api_key_moderated, log_prompt
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
tokenizer = tiktoken.get_encoding("cl100k_base")
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
from llm_server.threads import add_moderation_task, get_results
class OpenAIRequestHandler(RequestHandler):
@ -32,39 +22,32 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
else:
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
self.prompt = self.transform_messages_to_prompt()
if opts.openai_api_key:
if opts.openai_api_key and is_api_key_moderated(self.token):
try:
# Gather the last message from the user and all preceeding system messages
msg_l = self.request.json['messages'].copy()
msg_l.reverse()
msgs_to_check = []
for msg in msg_l:
if msg['role'] == 'system':
msgs_to_check.append(msg['content'])
elif msg['role'] == 'user':
msgs_to_check.append(msg['content'])
break
tag = uuid4()
num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
for i in range(num_to_check):
add_moderation_task(msg_l[i]['content'], tag)
flagged = False
flagged_categories = []
# TODO: make this threaded
for msg in msgs_to_check:
flagged, categories = check_moderation_endpoint(msg)
flagged_categories.extend(categories)
if flagged:
break
flagged_categories = get_results(tag, num_to_check)
if flagged and len(flagged_categories):
mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
if len(flagged_categories):
mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
self.prompt = self.transform_messages_to_prompt()
# print(json.dumps(self.request.json['messages'], indent=4))
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
@ -92,32 +75,8 @@ class OpenAIRequestHandler(RequestHandler):
log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
return build_openai_response(self.prompt, backend_response), 429
def transform_messages_to_prompt(self):
# TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in self.request.json['messages']:
if not msg.get('content') or not msg.get('role'):
return False
if msg['role'] == 'system':
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
elif msg['role'] == 'user':
prompt += f'### USER: {msg["content"]}\n\n'
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
return False
except Exception as e:
# TODO: use logging
print(f'Failed to transform OpenAI to prompt:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
return ''
prompt = prompt.strip(' ').strip('\n').strip('\n\n') # TODO: this is really lazy
prompt += '\n\n### RESPONSE: '
return prompt
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
print(msg)
# return build_openai_response('', msg), 400
return jsonify({
"error": {
@ -127,60 +86,3 @@ class OpenAIRequestHandler(RequestHandler):
"code": None
}
}), 400
def check_moderation_endpoint(prompt: str):
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {opts.openai_api_key}",
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
offending_categories = []
for k, v in response['results'][0]['categories'].items():
if v:
offending_categories.append(k)
return response['results'][0]['flagged'], offending_categories
def build_openai_response(prompt, response, model=None):
# Seperate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
return jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
})
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))

View File

@ -10,6 +10,7 @@ from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import require_api_key, validate_json
@ -21,6 +22,7 @@ DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = self.request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it
if incoming_json:
@ -42,11 +44,12 @@ class RequestHandler:
redis.set_dict('recent_prompters', recent_prompters)
def get_auth_token(self):
websocket_key = self.request_json_body.get('X-API-KEY')
if websocket_key:
return websocket_key
else:
return self.request.headers.get('X-Api-Key')
if self.request_json_body.get('X-API-KEY'):
return self.request_json_body.get('X-API-KEY')
elif self.request.headers.get('X-Api-Key'):
return self.request.headers['X-Api-Key']
elif self.request.headers.get('Authorization'):
return parse_token(self.request.headers['Authorization'])
def get_client_ip(self):
if self.request.headers.get('X-Connecting-IP'):

View File

@ -96,7 +96,6 @@ class SemaphoreCheckerThread(Thread):
new_recent_prompters = {}
for ip, (timestamp, token) in recent_prompters.items():
# TODO: validate token
if token and token.startswith('SYSTEM__'):
continue
if current_time - timestamp <= 300:

View File

@ -17,6 +17,6 @@ def generate():
try:
return OobaRequestHandler(request).handle_request()
except Exception as e:
print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
print(f'EXCEPTION on {request.url}!!!')
print(traceback.format_exc())
return format_sillytavern_err(f'Server encountered exception.', 'error'), 500

View File

@ -57,9 +57,14 @@ def stream(ws):
err_msg = invalid_response[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'message_num': 0,
'text': err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close() # this is important if we encountered an error and exited early.
def background_task():
log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
@ -98,12 +103,18 @@ def stream(ws):
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
return
message_num += 1
partial_response = b'' # Reset the partial response
@ -117,8 +128,7 @@ def stream(ws):
elapsed_time = end_time - start_time
def background_task_success():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_success)
@ -141,9 +151,13 @@ def stream(ws):
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
ws.close() # this is important if we encountered an error and exited early.
except:
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))

View File

@ -4,7 +4,7 @@ from flask import jsonify, request
from . import bp
from ..auth import requires_auth
from ..cache import cache
from ..cache import flask_cache
from ... import opts
from ...llm.info import get_running_model
@ -21,7 +21,7 @@ def get_model():
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
cached_response = cache.get(cache_key)
cached_response = flask_cache.get(cache_key)
if cached_response:
return cached_response
@ -38,7 +38,7 @@ def get_model():
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'timestamp': int(time.time())
}), 200
cache.set(cache_key, response, timeout=60)
flask_cache.set(cache_key, response, timeout=60)
return response

View File

@ -2,11 +2,11 @@ from flask import jsonify
from . import bp
from .generate_stats import generate_stats
from ..cache import cache
from ..cache import flask_cache
from ...helpers import jsonify_pretty
@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=5, query_string=True)
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())

View File

@ -1,9 +1,15 @@
import json
import threading
import time
import traceback
from threading import Thread
import redis as redis_redis
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
@ -70,3 +76,45 @@ def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)
redis_moderation = redis_redis.Redis()
def start_moderation_workers(num_workers):
for _ in range(num_workers):
t = threading.Thread(target=moderation_worker)
t.daemon = True
t.start()
def moderation_worker():
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results.
flagged_categories = set()
num_results = 0
while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories')
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
return list(flagged_categories)
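
The Redis-backed queue above fans moderation requests out to a pool of daemon worker threads. A minimal sketch of the producer side, mirroring how server.py and OpenAIRequestHandler use it in this commit:

```python
# Illustrative producer flow for the moderation queue.
from uuid import uuid4

from llm_server import opts
from llm_server.threads import add_moderation_task, get_results, start_moderation_workers

start_moderation_workers(opts.openai_moderation_workers)  # done once at startup

tag = uuid4()
add_moderation_task('first message to scan', tag)
add_moderation_task('second message to scan', tag)

# Blocks until a worker has pushed a result for each of the two tasks.
flagged_categories = get_results(tag, 2)
print(flagged_categories)
```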

View File

@ -10,6 +10,6 @@ To test on your local machine, run this command:
docker run --shm-size 14g --gpus all \
-v /storage/models/awq/MythoMax-L2-13B-AWQ:/models/MythoMax-L2-13B-AWQ \
-p 7000:7000 -p 8888:8888 \
-e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 8192 --gpu-memory-utilization 1" \
-e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 99999 --gpu-memory-utilization 1" \
vllm-cloud
```

View File

@ -6,7 +6,7 @@ requests~=2.31.0
tiktoken~=0.5.0
gunicorn
redis~=5.0.0
gevent
gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
@ -18,3 +18,5 @@ simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
rq~=1.15.1

View File

@ -1,3 +1,10 @@
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import os
import re
import sys
@ -18,14 +25,15 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: option to trim context in openai mode so that we silently fit the model's context
# TODO: validate system tokens before excluding them
# TODO: make sure prompts are logged even when the user cancels generation
# TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: use the current estimated wait time for ratelimit headers on openai
# TODO: accept a header to specify if the openai endpoint should return sillytavern-formatted errors
# TODO: allow setting concurrent gens per-backend
# TODO: use first backend as default backend
# TODO: allow disabling OpenAI moderation endpoint per-token
# TODO: allow setting specific simultaneous IPs allowed per token
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
# TODO: add more excluding to SYSTEM__ tokens
@ -43,11 +51,12 @@ from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
script_path = os.path.dirname(os.path.realpath(__file__))
@ -55,8 +64,8 @@ app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
cache.init_app(app)
cache.clear() # clear redis cache
flask_cache.init_app(app)
flask_cache.clear()
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
@ -81,7 +90,7 @@ if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
# TODO: this is a MESS
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
@ -108,6 +117,11 @@ opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to true, you must set your OpenAI key in openai_api_key.')
@ -134,6 +148,7 @@ else:
def pre_fork(server):
llm_server.llm.redis = RedisWrapper('local_llm')
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
@ -148,6 +163,7 @@ def pre_fork(server):
# Start background processes
start_workers(opts.concurrent_gens)
start_moderation_workers(opts.openai_moderation_workers)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
@ -170,7 +186,7 @@ def pre_fork(server):
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@cache.cached(timeout=10)
@flask_cache.cached(timeout=10)
def home():
stats = generate_stats()
@ -245,4 +261,6 @@ def before_app_request():
if __name__ == "__main__":
pre_fork(None)
print('FLASK MODE - Startup complete!')
app.run(host='0.0.0.0', threaded=False, processes=15)