Compare commits
7 Commits
Author | SHA1 | Date
---|---|---
Cyberes | 2ab2e6eed1 |
Cyberes | 20366fbd08 |
Cyberes | fe23a2282f |
Cyberes | 5bd1044fad |
Cyberes | fd09c783d3 |
Cyberes | ee9a0d4858 |
Cyberes | ff82add09e |
@@ -111,3 +111,4 @@ Then, update the VLLM version in `requirements.txt`.
- [ ] Make sure stats work when starting from an empty database
- [ ] Make sure we're correctly canceling requests when the client cancels. The blocking endpoints can't detect when a client cancels generation.
- [ ] Add test to verify the OpenAI endpoint works as expected
- [ ] Document the `Llm-Disable-Openai` header
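The last two checklist items can be sketched even before they are implemented. The snippet below is only an illustration: the hostname is a placeholder, the exact route and the behaviour of the `Llm-Disable-Openai` header (presumably it maps to the new `disable_openai_handling` flag in `transform_messages_to_prompt`) are assumptions, since neither is documented yet in this changeset.

```python
# Hypothetical smoke test for the OpenAI-compatible endpoint; URL and header semantics are assumptions.
import requests

resp = requests.post(
    'https://example.com/api/openai/v1/chat/completions',  # assumed OpenAI-compatible route
    headers={
        'Authorization': 'Bearer <token>',
        'Llm-Disable-Openai': 'true',  # assumed: skips the injected openai_system_prompt
    },
    json={
        'model': 'some-model',
        'messages': [{'role': 'user', 'content': 'Hello!'}],
    },
    timeout=30,
)
assert resp.status_code == 200
assert 'choices' in resp.json()
```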
daemon.py
@@ -3,14 +3,17 @@ import logging
import os
import sys
import time
from pathlib import Path

from redis import Redis

from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.load import load_config, parse_backends
from llm_server.config.global_config import GlobalConfig
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
@@ -20,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')
    config_path = resolve_path(script_path, 'config', 'config.yml')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Daemon microservice.')
@@ -28,7 +31,6 @@ if __name__ == "__main__":
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
    args = parser.parse_args()

    # TODO: have this be set by either the arg or a config value
    if args.debug:
        logging_info.level = logging.DEBUG

@@ -40,19 +42,23 @@ if __name__ == "__main__":
    Redis().flushall()
    logger.info('Flushed Redis.')

    success, config, msg = load_config(config_path)
    success, msg = load_config(config_path)
    if not success:
        logger.info(f'Failed to load config: {msg}')
        sys.exit(1)

    Database.initialise(**GlobalConfig.get().postgresql.dict())
    create_db()

    cluster_config.clear()
    cluster_config.load(parse_backends(config))
    cluster_config.load()

    logger.info('Loading backend stats...')
    generate_stats(regen=True)

    if GlobalConfig.get().load_num_prompts:
        redis.set('proompts', get_number_of_rows('messages'))

    start_background()

    # Give some time for the background threads to get themselves ready to go.

@@ -1,8 +1,8 @@
import numpy as np

from llm_server import opts
from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
from llm_server.cluster.stores import redis_running_models
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.llm.info import get_info
@@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str):
    :param model_name:
    :return:
    """
    assert isinstance(model_name, str)
    return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]


@@ -25,7 +26,7 @@ def get_running_models():
    Get all the models that are in the cluster.
    :return:
    """
    return list(redis_running_models.keys())
    return [x for x in list(redis_running_models.keys())]


def is_valid_model(model_name: str) -> bool:
@@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:

    base_client_api = redis.get('base_client_api', dtype=str)
    running_models = get_running_models()

    model_choices = {}
    for model in running_models:
        b = get_backends_from_model(model)
@@ -108,8 +110,8 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
        model_choices[model] = {
            'model': model,
            'client_api': f'https://{base_client_api}/{model}',
            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if GlobalConfig.get().enable_streaming else None,
            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if GlobalConfig.get().enable_openi_compatible_backend else 'disabled',
            'backend_count': len(b),
            'estimated_wait': estimated_wait_sec,
            'queued': proompters_in_queue,

@@ -2,15 +2,18 @@ import hashlib
import pickle
import traceback

from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom
from llm_server.logging import create_logger
from llm_server.routes.helpers.model import estimate_model_size


# Don't try to reorganize this file or else you'll run into circular imports.

_logger = create_logger('redis')


class RedisClusterStore:
    """
    A class used to store the cluster state in Redis.
@@ -23,9 +26,14 @@ class RedisClusterStore:
    def clear(self):
        self.config_redis.flush()

    def load(self, config: dict):
        for k, v in config.items():
            self.add_backend(k, v)
    def load(self):
        stuff = {}
        for item in GlobalConfig.get().cluster:
            backend_url = item.backend_url.strip('/')
            item.backend_url = backend_url
            stuff[backend_url] = item
        for k, v in stuff.items():
            self.add_backend(k, v.dict())

    def add_backend(self, name: str, values: dict):
        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
@@ -67,7 +75,7 @@ class RedisClusterStore:
        if not backend_info['online']:
            old = backend_url
            backend_url = get_a_cluster_backend()
            print(f'Backend {old} offline. Request was redirected to {backend_url}')
            _logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
        return backend_url


@@ -89,7 +97,7 @@ def get_backends():
        result[k] = {'status': status, 'priority': priority}

    try:
        if not opts.prioritize_by_size:
        if not GlobalConfig.get().prioritize_by_size:
            online_backends = sorted(
                ((url, info) for url, info in backends.items() if info['online']),
                key=lambda kv: -kv[1]['priority'],
@@ -108,8 +116,7 @@ def get_backends():
            )
        return [url for url, info in online_backends], [url for url, info in offline_backends]
    except KeyError:
        traceback.print_exc()
        print(backends)
        _logger.err(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')


def get_a_cluster_backend(model=None):

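The new `load()` path stores each cluster entry as a Redis hash whose values are pickled, as the `hset` call above shows. A small round-trip sketch of that storage format (the backend URL and numbers are placeholders, not values from this changeset):

```python
# Illustration of what add_backend() writes into the Redis hash, per the hset(mapping=...) call above.
import pickle

values = {'backend_url': 'http://10.0.0.50:7000', 'mode': 'vllm', 'priority': 100, 'concurrent_gens': 3}
mapping = {k: pickle.dumps(v) for k, v in values.items()}   # what config_redis.hset(name, mapping=...) receives
restored = {k: pickle.loads(v) for k, v in mapping.items()}  # what readers get back after unpickling
assert restored == values
```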
@@ -1,81 +1,22 @@
import yaml
from pydantic import BaseModel

config_default_vars = {
    'log_prompts': False,
    'auth_required': False,
    'frontend_api_client': '',
    'verify_ssl': True,
    'load_num_prompts': False,
    'show_num_prompts': True,
    'show_uptime': True,
    'analytics_tracking_code': '',
    'average_generation_time_mode': 'database',
    'info_html': None,
    'show_total_output_tokens': True,
    'simultaneous_requests_per_ip': 3,
    'max_new_tokens': 500,
    'manual_model_name': False,
    'enable_streaming': True,
    'enable_openi_compatible_backend': True,
    'openai_api_key': None,
    'expose_openai_system_prompt': True,
    'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
    'http_host': None,
    'admin_token': None,
    'openai_expose_our_model': False,
    'openai_force_no_hashes': True,
    'include_system_tokens_in_stats': True,
    'openai_moderation_scan_last_n': 5,
    'openai_org_name': 'OpenAI',
    'openai_silent_trim': False,
    'openai_moderation_enabled': True,
    'netdata_root': None,
    'show_backends': True,
    'background_homepage_cacher': True,
    'openai_moderation_timeout': 5,
    'prioritize_by_size': False
from llm_server.config.global_config import GlobalConfig


def cluster_worker_count():
    count = 0
    for item in GlobalConfig.get().cluster:
        count += item.concurrent_gens
    return count


class ModeUINameStr(BaseModel):
    name: str
    api_name: str
    streaming_name: str


MODE_UI_NAMES = {
    'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
    'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
}
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']

mode_ui_names = {
    'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
    'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
}


class ConfigLoader:
    """
    default_vars = {"var1": "default1", "var2": "default2"}
    required_vars = ["var1", "var3"]
    config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
    config = config_loader.load_config()
    """

    def __init__(self, config_file, default_vars=None, required_vars=None):
        self.config_file = config_file
        self.default_vars = default_vars if default_vars else {}
        self.required_vars = required_vars if required_vars else []
        self.config = {}

    def load_config(self) -> (bool, str | None, str | None):
        with open(self.config_file, 'r') as stream:
            try:
                self.config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                return False, None, exc

        if self.config is None:
            # Handle empty file
            self.config = {}

        # Set default variables if they are not present in the config file
        for var, default_value in self.default_vars.items():
            if var not in self.config:
                self.config[var] = default_value

        # Check if required variables are present in the config file
        for var in self.required_vars:
            if var not in self.config:
                return False, None, f'Required variable "{var}" is missing from the config file'

        return True, self.config, None

@@ -0,0 +1,17 @@
from llm_server.config.model import ConfigModel


class GlobalConfig:
    __config_model: ConfigModel = None

    @classmethod
    def initalize(cls, config: ConfigModel):
        if cls.__config_model is not None:
            raise Exception('Config is already initialised')
        cls.__config_model = config

    @classmethod
    def get(cls):
        if cls.__config_model is None:
            raise Exception('Config has not been initialised')
        return cls.__config_model

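This new holder is a process-wide singleton: `load_config()` (further down in this compare) builds a `ConfigModel` and registers it once, after which any module can read settings without going through `opts`. A minimal usage sketch, not part of the changeset itself:

```python
# Sketch of the intended pattern; GlobalConfig.get() raises until the config has been registered.
from llm_server.config.global_config import GlobalConfig

# during startup, after the config file has been parsed into a ConfigModel:
# GlobalConfig.initalize(config_model)   # spelled "initalize" in this changeset

# anywhere else afterwards:
if GlobalConfig.get().log_prompts:
    print('prompt logging is enabled')
```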
@@ -1,91 +1,86 @@
import re
import sys
from pathlib import Path

import openai
from bison import bison, Option, ListOption, Scheme

import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.config.global_config import GlobalConfig
from llm_server.config.model import ConfigModel
from llm_server.config.scheme import config_scheme
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.logging import create_logger
from llm_server.routes.queue import PriorityQueue

_logger = create_logger('config')

def load_config(config_path):
    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
    success, config, msg = config_loader.load_config()
    if not success:
        return success, config, msg

    # TODO: this is atrocious
    opts.auth_required = config['auth_required']
    opts.log_prompts = config['log_prompts']
    opts.frontend_api_client = config['frontend_api_client']
    opts.show_num_prompts = config['show_num_prompts']
    opts.show_uptime = config['show_uptime']
    opts.cluster = config['cluster']
    opts.show_total_output_tokens = config['show_total_output_tokens']
    opts.netdata_root = config['netdata_root']
    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
    opts.max_new_tokens = config['max_new_tokens']
    opts.manual_model_name = config['manual_model_name']
    opts.llm_middleware_name = config['llm_middleware_name']
    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
    opts.openai_system_prompt = config['openai_system_prompt']
    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
    opts.enable_streaming = config['enable_streaming']
    opts.openai_api_key = config['openai_api_key']
    openai.api_key = opts.openai_api_key
    opts.admin_token = config['admin_token']
    opts.openai_expose_our_model = config['openai_expose_our_model']
    opts.openai_force_no_hashes = config['openai_force_no_hashes']
    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
    opts.openai_org_name = config['openai_org_name']
    opts.openai_silent_trim = config['openai_silent_trim']
    opts.openai_moderation_enabled = config['openai_moderation_enabled']
    opts.show_backends = config['show_backends']
    opts.background_homepage_cacher = config['background_homepage_cacher']
    opts.openai_moderation_timeout = config['openai_moderation_timeout']
    opts.frontend_api_mode = config['frontend_api_mode']
    opts.prioritize_by_size = config['prioritize_by_size']
def validate_config(config: bison.Bison):
    def do(v, scheme: Scheme = None):
        if isinstance(v, Option) and v.choices is None:
            if not isinstance(config.config[v.name], v.type):
                raise ValueError(f'"{v.name}" must be type {v.type}. Current value: "{config.config[v.name]}"')
        elif isinstance(v, Option) and v.choices is not None:
            if config.config[v.name] not in v.choices:
                raise ValueError(f'"{v.name}" must be one of {v.choices}. Current value: "{config.config[v.name]}"')
        elif isinstance(v, ListOption):
            if isinstance(config.config[v.name], list):
                for item in config.config[v.name]:
                    do(item, v.member_scheme)
            elif isinstance(config.config[v.name], dict):
                for kk, vv in config.config[v.name].items():
                    scheme_dict = v.member_scheme.flatten()
                    if not isinstance(vv, scheme_dict[kk].type):
                        raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
                    elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
                        if vv not in scheme_dict[kk].choices:
                            raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')
        elif isinstance(v, dict) and scheme is not None:
            scheme_dict = scheme.flatten()
            for kk, vv in v.items():
                if not isinstance(vv, scheme_dict[kk].type):
                    raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
                elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
                    if vv not in scheme_dict[kk].choices:
                        raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')

    # Scale the number of workers.
    for item in config['cluster']:
        opts.cluster_workers += item['concurrent_gens']
    for k, v in config_scheme.flatten().items():
        do(v)

    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])

    if opts.openai_expose_our_model and not opts.openai_api_key:
        print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
def load_config(config_path: Path):
    config = bison.Bison(scheme=config_scheme)
    config.config_name = 'config'
    config.add_config_paths(str(config_path.parent))
    config.parse()

    try:
        validate_config(config)
    except ValueError as e:
        return False, str(e)

    config_model = ConfigModel(**config.config)
    GlobalConfig.initalize(config_model)

    if GlobalConfig.get().postgresql.maxconn < 0:
        return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"'

    openai.api_key = GlobalConfig.get().openai_api_key

    llm_server.routes.queue.priority_queue = PriorityQueue(set([x.backend_url for x in config_model.cluster]))

    if GlobalConfig.get().openai_expose_our_model and not GlobalConfig.get().openai_api_key:
        _logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
        sys.exit(1)

    opts.verify_ssl = config['verify_ssl']
    if not opts.verify_ssl:
    if not GlobalConfig.get().verify_ssl:
        import urllib3
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    if config['http_host']:
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
    if GlobalConfig.get().http_host:
        http_host = re.sub(r'https?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
        redis.set('base_client_api', f'{http_host}/{GlobalConfig.get().frontend_api_client.strip("/")}')

    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    return success, config, msg


def parse_backends(config):
    if not config.get('cluster'):
        return False
    cluster = config.get('cluster')
    config = {}
    for item in cluster:
        backend_url = item['backend_url'].strip('/')
        item['backend_url'] = backend_url
        config[backend_url] = item
    return config
    return True, None

@@ -0,0 +1,75 @@
from enum import Enum
from typing import Union, List

from pydantic import BaseModel


class ConfigClusterMode(str, Enum):
    vllm = 'vllm'


class ConfigCluser(BaseModel):
    backend_url: str
    concurrent_gens: int
    mode: ConfigClusterMode
    priority: int


class ConfigFrontendApiModes(str, Enum):
    ooba = 'ooba'


class ConfigPostgresql(BaseModel):
    host: str
    user: str
    password: str
    database: str
    maxconn: int


class ConfigAvgGenTimeModes(str, Enum):
    database = 'database'
    minute = 'minute'


class ConfigModel(BaseModel):
    frontend_api_mode: ConfigFrontendApiModes
    cluster: List[ConfigCluser]
    prioritize_by_size: bool
    admin_token: Union[str, None]
    postgresql: ConfigPostgresql
    http_host: str
    include_system_tokens_in_stats: bool
    background_homepage_cacher: bool
    max_new_tokens: int
    enable_streaming: int
    show_backends: bool
    log_prompts: bool
    verify_ssl: bool
    auth_required: bool
    simultaneous_requests_per_ip: int
    max_queued_prompts_per_ip: int
    llm_middleware_name: str
    analytics_tracking_code: Union[str, None]
    info_html: Union[str, None]
    enable_openi_compatible_backend: bool
    openai_api_key: Union[str, None]
    openai_system_prompt: str
    expose_openai_system_prompt: bool
    openai_expose_our_model: bool
    openai_force_no_hashes: bool
    openai_moderation_enabled: bool
    openai_moderation_timeout: int
    openai_moderation_scan_last_n: int
    openai_org_name: str
    openai_silent_trim: bool
    frontend_api_client: str
    average_generation_time_mode: ConfigAvgGenTimeModes
    show_num_prompts: bool
    show_uptime: bool
    show_total_output_tokens: bool
    show_backend_info: bool
    load_num_prompts: bool
    manual_model_name: Union[str, None]
    backend_request_timeout: int
    backend_generate_request_timeout: int

@@ -0,0 +1,59 @@
from typing import Union

import bison

from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT

config_scheme = bison.Scheme(
    bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
    bison.ListOption('cluster', member_scheme=bison.Scheme(
        bison.Option('backend_url', field_type=str),
        bison.Option('concurrent_gens', field_type=int),
        bison.Option('mode', choices=['vllm'], field_type=str),
        bison.Option('priority', field_type=int),
    )),
    bison.Option('prioritize_by_size', default=True, field_type=bool),
    bison.Option('admin_token', default=None, field_type=Union[str, None]),
    bison.ListOption('postgresql', member_scheme=bison.Scheme(
        bison.Option('host', field_type=str),
        bison.Option('user', field_type=str),
        bison.Option('password', field_type=str),
        bison.Option('database', field_type=str),
        bison.Option('maxconn', field_type=int)
    )),
    bison.Option('http_host', default='', field_type=str),
    bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
    bison.Option('background_homepage_cacher', default=True, field_type=bool),
    bison.Option('max_new_tokens', default=500, field_type=int),
    bison.Option('enable_streaming', default=True, field_type=bool),
    bison.Option('show_backends', default=True, field_type=bool),
    bison.Option('log_prompts', default=True, field_type=bool),
    bison.Option('verify_ssl', default=False, field_type=bool),
    bison.Option('auth_required', default=False, field_type=bool),
    bison.Option('simultaneous_requests_per_ip', default=1, field_type=int),
    bison.Option('max_queued_prompts_per_ip', default=1, field_type=int),
    bison.Option('llm_middleware_name', default='LocalLLM', field_type=str),
    bison.Option('analytics_tracking_code', default=None, field_type=Union[str, None]),
    bison.Option('info_html', default=None, field_type=Union[str, None]),
    bison.Option('enable_openi_compatible_backend', default=True, field_type=bool),
    bison.Option('openai_api_key', default=None, field_type=Union[str, None]),
    bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
    bison.Option('openai_expose_our_model', default='', field_type=bool),
    bison.Option('openai_force_no_hashes', default=True, field_type=bool),
    bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str),
    bison.Option('openai_moderation_enabled', default=False, field_type=bool),
    bison.Option('openai_moderation_timeout', default=5, field_type=int),
    bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
    bison.Option('openai_org_name', default='OpenAI', field_type=str),
    bison.Option('openai_silent_trim', default=True, field_type=bool),
    bison.Option('frontend_api_client', default='/api', field_type=str),
    bison.Option('average_generation_time_mode', default='database', choices=['database', 'minute'], field_type=str),
    bison.Option('show_num_prompts', default=True, field_type=bool),
    bison.Option('show_uptime', default=True, field_type=bool),
    bison.Option('show_total_output_tokens', default=True, field_type=bool),
    bison.Option('show_backend_info', default=True, field_type=bool),
    bison.Option('load_num_prompts', default=True, field_type=bool),
    bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
    bison.Option('backend_request_timeout', default=30, field_type=int),
    bison.Option('backend_generate_request_timeout', default=95, field_type=int)
)

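Read against this scheme, most options now fall back to a default, while `frontend_api_mode`, `cluster`, and `postgresql` appear to have no usable default and must come from the user's `config.yml`. A sketch of that minimal user-supplied portion, expressed as the dict bison would hand to `ConfigModel` (all values are placeholders, not endpoints from this changeset):

```python
# Hypothetical minimal user configuration implied by the scheme above; values are placeholders.
minimal_config = {
    'frontend_api_mode': 'ooba',
    'cluster': [
        {'backend_url': 'http://10.0.0.50:7000', 'concurrent_gens': 3, 'mode': 'vllm', 'priority': 100},
    ],
    'postgresql': {
        'host': 'localhost', 'user': 'llm_server', 'password': 'changeme',
        'database': 'llm_server', 'maxconn': 10,
    },
}
```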
@@ -1,13 +1,11 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Optional, Union
from typing import Union

import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT
from redis.typing import ExpiryT, KeyT, PatternT

flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})

@@ -26,28 +24,15 @@ class RedisCustom(Redis):
        super().__init__()
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
            print('Failed to connect to the Redis server:', e)
            print('Did you install and start the Redis server?')
            sys.exit(1)

    def _key(self, key):
        return f"{self.prefix}:{key}"

    def set(self, key: KeyT,
            value: EncodableT,
            ex: Union[ExpiryT, None] = None,
            px: Union[ExpiryT, None] = None,
            nx: bool = False,
            xx: bool = False,
            keepttl: bool = False,
            get: bool = False,
            exat: Union[AbsExpiryT, None] = None,
            pxat: Union[AbsExpiryT, None] = None
            ):
        return self.redis.set(self._key(key), value, ex=ex)
    def execute_command(self, *args, **options):
        if args[0] != 'GET':
            args = list(args)
            args[1] = self._key(args[1])
        return super().execute_command(*args, **options)

    def get(self, key, default=None, dtype=None):
        # TODO: use pickle
@@ -71,103 +56,6 @@ class RedisCustom(Redis):
        else:
            return d

    def incr(self, key, amount=1):
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        return self.redis.sismember(self._key(key), value)

    def lindex(
            self, name: str, index: int
    ):
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        return self.redis.llen(self._key(name))

    def zrangebyscore(
            self,
            name: KeyT,
            min: ZScoreBoundT,
            max: ZScoreBoundT,
            start: Union[int, None] = None,
            num: Union[int, None] = None,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
    ):
        return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)

    def zremrangebyscore(
            self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
    ):
        return self.redis.zremrangebyscore(self._key(name), min, max)

    def hincrby(
            self, name: str, key: str, amount: int = 1
    ):
        return self.redis.hincrby(self._key(name), key, amount)

    def zcard(self, name: KeyT):
        return self.redis.zcard(self._key(name))

    def hdel(self, name: str, *keys: str):
        return self.redis.hdel(self._key(name), *keys)

    def hget(
            self, name: str, key: str
    ):
        return self.redis.hget(self._key(name), key)

    def zadd(
            self,
            name: KeyT,
            mapping: Mapping[AnyKeyT, EncodableT],
            nx: bool = False,
            xx: bool = False,
            ch: bool = False,
            incr: bool = False,
            gt: bool = False,
            lt: bool = False,
    ):
        return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)

    def lpush(self, name: str, *values: FieldT):
        return self.redis.lpush(self._key(name), *values)

    def hset(
            self,
            name: str,
            key: Optional = None,
            value=None,
            mapping: Optional[dict] = None,
            items: Optional[list] = None,
    ):
        return self.redis.hset(self._key(name), key, value, mapping, items)

    def hkeys(self, name: str):
        return self.redis.hkeys(self._key(name))

    def hmget(self, name: str, keys: List, *args: List):
        return self.redis.hmget(self._key(name), keys, *args)

    def hgetall(self, name: str):
        return self.redis.hgetall(self._key(name))

    def keys(self, pattern: PatternT = "*", **kwargs):
        raw_keys = self.redis.keys(self._key(pattern), **kwargs)
        keys = []

@@ -177,25 +65,9 @@ class RedisCustom(Redis):
            # Delete prefix
            del p[0]
            k = ':'.join(p)
            if k != '____':
                keys.append(k)
            keys.append(k)
        return keys

    def pipeline(self, transaction=True, shard_hint=None):
        return self.redis.pipeline(transaction, shard_hint)

    def smembers(self, name: str):
        return self.redis.smembers(self._key(name))

    def spop(self, name: str, count: Optional[int] = None):
        return self.redis.spop(self._key(name), count)

    def rpoplpush(self, src, dst):
        return self.redis.rpoplpush(src, dst)

    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
        return self.redis.zpopmin(self._key(name), count)

    def exists(self, *names: KeyT):
        n = []
        for name in names:
@@ -236,32 +108,5 @@ class RedisCustom(Redis):
        self.flush()
        return True

    def lrange(self, name: str, start: int, end: int):
        return self.redis.lrange(self._key(name), start, end)

    def delete(self, *names: KeyT):
        return self.redis.delete(*[self._key(i) for i in names])

    def lpop(self, name: str, count: Optional[int] = None):
        return self.redis.lpop(self._key(name), count)

    def zrange(
            self,
            name: KeyT,
            start: int,
            end: int,
            desc: bool = False,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
            byscore: bool = False,
            bylex: bool = False,
            offset: int = None,
            num: int = None,
    ):
        return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)

    def zrem(self, name: KeyT, *values: FieldT):
        return self.redis.zrem(self._key(name), *values)


redis = RedisCustom('local_llm')

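The explicit wrapper methods removed above were all doing the same thing: prefixing the key and delegating, which the new `execute_command` override now does generically. A short sketch of the prefixing behaviour, assuming a Redis server is running locally (the module-level instance uses the `local_llm` prefix):

```python
# Illustration of the key prefixing implied by _key(); requires a reachable Redis server.
from llm_server.custom_redis import redis

redis.set('backend_online', 1)                  # stored under the Redis key 'local_llm:backend_online'
value = redis.get('backend_online', dtype=int)  # this wrapper's get() accepts a dtype hint
```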
@@ -1,28 +1,43 @@
import pymysql
from psycopg2 import pool, InterfaceError


class DatabaseConnection:
    host: str = None
    username: str = None
    password: str = None
    database_name: str = None
class Database:
    __connection_pool = None

    def init_db(self, host, username, password, database_name):
        self.host = host
        self.username = username
        self.password = password
        self.database_name = database_name
    @classmethod
    def initialise(cls, maxconn, **kwargs):
        if cls.__connection_pool is not None:
            raise Exception('Database connection pool is already initialised')
        cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs)

    def cursor(self):
        db = pymysql.connect(
            host=self.host,
            user=self.username,
            password=self.password,
            database=self.database_name,
            charset='utf8mb4',
            autocommit=True,
        )
        return db.cursor()
    @classmethod
    def get_connection(cls):
        return cls.__connection_pool.getconn()

    @classmethod
    def return_connection(cls, connection):
        cls.__connection_pool.putconn(connection)


database = DatabaseConnection()
class CursorFromConnectionFromPool:
    def __init__(self, cursor_factory=None):
        self.conn = None
        self.cursor = None
        self.cursor_factory = cursor_factory

    def __enter__(self):
        self.conn = Database.get_connection()
        self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
        return self.cursor

    def __exit__(self, exception_type, exception_value, exception_traceback):
        if exception_value is not None:  # This is equivalent of saying if there is an exception
            try:
                self.conn.rollback()
            except InterfaceError as e:
                if e != 'connection already closed':
                    raise
        else:
            self.cursor.close()
            self.conn.commit()
        Database.return_connection(self.conn)

@@ -1,40 +1,42 @@
from llm_server.database.conn import database
from llm_server.database.conn import CursorFromConnectionFromPool


def create_db():
    cursor = database.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS prompts (
            ip TEXT,
            token TEXT DEFAULT NULL,
            model TEXT,
            backend_mode TEXT,
            backend_url TEXT,
            request_url TEXT,
            generation_time FLOAT,
            prompt LONGTEXT,
            prompt_tokens INTEGER,
            response LONGTEXT,
            response_tokens INTEGER,
            response_status INTEGER,
            parameters TEXT,
            # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
            headers TEXT,
            # CHECK (headers IS NULL OR JSON_VALID(headers)),
            timestamp INTEGER
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS token_auth (
            token TEXT,
            UNIQUE (token),
            type TEXT NOT NULL,
            priority INTEGER DEFAULT 9999,
            simultaneous_ip INTEGER DEFAULT NULL,
            uses INTEGER DEFAULT 0,
            max_uses INTEGER,
            expire INTEGER,
            disabled BOOLEAN DEFAULT 0
        )
    ''')
    cursor.close()
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS public.messages
            (
                ip text COLLATE pg_catalog."default" NOT NULL,
                token text COLLATE pg_catalog."default",
                model text COLLATE pg_catalog."default" NOT NULL,
                backend_mode text COLLATE pg_catalog."default" NOT NULL,
                backend_url text COLLATE pg_catalog."default" NOT NULL,
                request_url text COLLATE pg_catalog."default" NOT NULL,
                generation_time double precision,
                prompt text COLLATE pg_catalog."default" NOT NULL,
                prompt_tokens integer NOT NULL,
                response text COLLATE pg_catalog."default" NOT NULL,
                response_tokens integer NOT NULL,
                response_status integer NOT NULL,
                parameters jsonb NOT NULL,
                headers jsonb,
                "timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
                id SERIAL PRIMARY KEY
            );
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS public.token_auth
            (
                token text COLLATE pg_catalog."default" NOT NULL,
                type text COLLATE pg_catalog."default" NOT NULL,
                priority integer NOT NULL DEFAULT 9999,
                simultaneous_ip text COLLATE pg_catalog."default",
                openai_moderation_enabled boolean NOT NULL DEFAULT true,
                uses integer NOT NULL DEFAULT 0,
                max_uses integer,
                expire timestamp with time zone,
                disabled boolean NOT NULL DEFAULT false,
                notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text,
                CONSTRAINT token_auth_pkey PRIMARY KEY (token)
            )
        ''')

@@ -1,11 +1,12 @@
import json
import time
import traceback
from datetime import datetime, timedelta
from typing import Union

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.config.global_config import GlobalConfig
from llm_server.database.conn import CursorFromConnectionFromPool
from llm_server.llm import get_token_count


@@ -33,15 +34,15 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
    if gen_time is not None:
        gen_time = round(gen_time, 3)
        if is_error:
            gen_time = None

    if not opts.log_prompts:
    if not GlobalConfig.get().log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
    if not GlobalConfig.get().log_prompts and not is_error:
        # TODO: test and verify this works as expected
        response = None

@@ -51,76 +52,58 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
    backend_info = cluster_config.get_backend(backend_url)
    running_model = backend_info.get('model')
    backend_mode = backend_info['mode']
    timestamp = int(time.time())
    cursor = database.cursor()
    try:
    timestamp = datetime.now()
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("""
            INSERT INTO prompts
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
            INSERT INTO messages
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    finally:
        cursor.close()


def is_valid_api_key(api_key):
    cursor = database.cursor()
    try:
def is_valid_api_key(api_key: str):
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False
    finally:
        cursor.close()
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False


def is_api_key_moderated(api_key):
    if not api_key:
        return opts.openai_moderation_enabled
    cursor = database.cursor()
    try:
        return GlobalConfig.get().openai_moderation_enabled
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            return bool(row[0])
        return opts.openai_moderation_enabled
    finally:
        cursor.close()
        return GlobalConfig.get().openai_moderation_enabled


def get_number_of_rows(table_name):
    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()


def average_column(table_name, column_name):
    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()


def average_column_for_model(table_name, column_name, model_name):
    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()


def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
@@ -129,8 +112,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
    else:
        sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"

    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        try:
            cursor.execute(sql, (model_name, backend_name, backend_url,))
            results = cursor.fetchall()
@@ -154,46 +136,34 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
                calculated_avg = 0

        return calculated_avg
    finally:
        cursor.close()


def sum_column(table_name, column_name):
    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0] if result else 0
    finally:
        cursor.close()


def get_distinct_ips_24h():
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    cursor = database.cursor()
    try:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
    past_24_hours = datetime.now() - timedelta(days=1)
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
        result = cursor.fetchone()
        return result[0] if result else 0
    finally:
        cursor.close()


def increment_token_uses(token):
    cursor = database.cursor()
    try:
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
    finally:
        cursor.close()


def get_token_ratelimit(token):
    priority = 9990
    simultaneous_ip = opts.simultaneous_requests_per_ip
    simultaneous_ip = GlobalConfig.get().simultaneous_requests_per_ip
    if token:
        cursor = database.cursor()
        try:
        with CursorFromConnectionFromPool() as cursor:
            cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
            result = cursor.fetchone()
            if result:
@@ -201,6 +171,4 @@ def get_token_ratelimit(token):
                if simultaneous_ip is None:
                    # No ratelimit for this token if null
                    simultaneous_ip = 999999999
        finally:
            cursor.close()
    return priority, simultaneous_ip

@@ -0,0 +1,10 @@
# Read-only global variables

DEFAULT_OPENAI_SYSTEM_PROMPT = ("You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. "
                                "You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, "
                                "apologize and suggest the user seek help elsewhere.")
OPENAI_FORMATTING_PROMPT = """Lines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user."""

REDIS_STREAM_TIMEOUT = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

@@ -7,7 +7,7 @@ from typing import Union
import simplejson as json
from flask import make_response

from llm_server import opts
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis


@@ -15,19 +15,6 @@ def resolve_path(*p: str):
    return Path(*p).expanduser().resolve().absolute()


def safe_list_get(l, idx, default):
    """
    https://stackoverflow.com/a/5125636
    :param l:
    :param idx:
    :param default:
    :return:
    """
    try:
        return l[idx]
    except IndexError:
        return default


def deep_sort(obj):
    if isinstance(obj, dict):
@@ -68,4 +55,4 @@ def auto_set_base_client_api(request):
        return
    else:
        redis.set('http_host', host)
        redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}')
        redis.set('base_client_api', f'{host}/{GlobalConfig.get().frontend_api_client.strip("/")}')

@@ -1,4 +1,4 @@
from llm_server import opts
from llm_server import globals
from llm_server.cluster.cluster_config import cluster_config


@@ -1,19 +1,19 @@
import requests

from llm_server import opts
from llm_server.config.global_config import GlobalConfig


def get_running_model(backend_url: str, mode: str):
    if mode == 'ooba':
        try:
            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
            r_json = backend_response.json()
            return r_json['result'], None
        except Exception as e:
            return False, e
    elif mode == 'vllm':
        try:
            backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
            backend_response = requests.get(f'{backend_url}/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
            r_json = backend_response.json()
            return r_json['model'], None
        except Exception as e:
@@ -28,7 +28,7 @@ def get_info(backend_url: str, mode: str):
        # raise NotImplementedError
    elif mode == 'vllm':
        try:
            r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
            r = requests.get(f'{backend_url}/info', verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_request_timeout)
            j = r.json()
        except Exception as e:
            return {}

@@ -5,17 +5,17 @@ import traceback

import requests

from llm_server import opts
from llm_server.config.global_config import GlobalConfig


def generate(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
# def generate(json_data: dict):
#     try:
#         r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
#     except requests.exceptions.ReadTimeout:
#         return False, None, 'Request to backend timed out'
#     except Exception as e:
#         traceback.print_exc()
#         return False, None, 'Request to backend encountered error'
#     if r.status_code != 200:
#         return False, r, f'Backend returned {r.status_code}'
#     return True, r, None

@@ -1,16 +1,19 @@
import requests

from llm_server import opts
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger

_logger = create_logger('moderation')


def check_moderation_endpoint(prompt: str):
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {opts.openai_api_key}",
        'Authorization': f"Bearer {GlobalConfig.get().openai_api_key}",
    }
    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
    if response.status_code != 200:
        print('moderation failed:', response)
        _logger.error(f'moderation failed: {response}')
        response.raise_for_status()
    response = response.json()

@@ -1,6 +1,9 @@
from flask import jsonify

from llm_server import opts
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger

_logger = create_logger('oai_to_vllm')


def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
@@ -11,7 +14,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
        request_json_body['stop'] = [request_json_body['stop']]

    if stop_hashes:
        if opts.openai_force_no_hashes:
        if GlobalConfig.get().openai_force_no_hashes:
            request_json_body['stop'].append('###')
        else:
            # TODO: make stopping strings a configurable
@@ -27,7 +30,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
    if mode == 'vllm' and request_json_body.get('top_p') == 0:
        request_json_body['top_p'] = 0.01

    request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
    request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), GlobalConfig.get().max_new_tokens)
    if request_json_body['max_tokens'] == 0:
        # We don't want to set any defaults here.
        del request_json_body['max_tokens']
@@ -36,7 +39,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):


def format_oai_err(err_msg):
    print('OAI ERROR MESSAGE:', err_msg)
    _logger.error(f'Got an OAI error message: {err_msg}')
    return jsonify({
        "error": {
            "message": err_msg,
@@ -87,11 +90,26 @@ def return_invalid_model_err(requested_model: str):
        msg = f"The model `{requested_model}` does not exist"
    else:
        msg = "The requested model does not exist"
    return_oai_invalid_request_error(msg)


def return_oai_internal_server_error():
    return jsonify({
        "error": {
            "message": "Internal server error",
            "type": None,
            "param": None,
            "code": "internal_error"
        }
    }), 500


def return_oai_invalid_request_error(msg: str = None):
    return jsonify({
        "error": {
            "message": msg,
            "type": "invalid_request_error",
            "param": None,
            "code": "model_not_found"
            "code": None
        }
    }), 404

@@ -7,7 +7,8 @@ from typing import Dict, List

import tiktoken

from llm_server import opts
from llm_server.config.global_config import GlobalConfig
from llm_server.globals import OPENAI_FORMATTING_PROMPT
from llm_server.llm import get_token_count

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
@@ -83,9 +84,14 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
    return prompt


def transform_messages_to_prompt(oai_messages):
def transform_messages_to_prompt(oai_messages: list, disable_openai_handling: bool = False):
    if not disable_openai_handling:
        prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}\n{OPENAI_FORMATTING_PROMPT}'
    else:
        prompt = f'### INSTRUCTION: {OPENAI_FORMATTING_PROMPT}'
    prompt = prompt + '\n\n'

    try:
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
        for msg in oai_messages:
            if 'content' not in msg.keys() or 'role' not in msg.keys():
                return False

@@ -4,7 +4,7 @@ This file is used by the worker that processes requests.

import requests

from llm_server import opts
from llm_server.config.global_config import GlobalConfig


# TODO: make the VLMM backend return TPS and time elapsed
@@ -25,13 +25,13 @@ def transform_prompt_to_text(prompt: list):

def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
    try:
        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
    except requests.exceptions.ReadTimeout:
        # print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
        return False, None, 'Request to backend timed out'
    except Exception as e:
        # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
        return False, None, 'Request to backend encountered error'
        return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}'
    if r.status_code != 200:
        # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
        return False, r, f'Backend returned {r.status_code}'
@@ -41,7 +41,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
def generate(json_data: dict, cluster_backend, timeout: int = None):
    if json_data.get('stream'):
        try:
            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
        except Exception as e:
            return False
    else:

@@ -1,7 +1,3 @@
import requests

from llm_server import opts

vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>
@@ -3,8 +3,8 @@ import concurrent.futures
import requests
import tiktoken

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger


@@ -32,7 +32,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
# Define a function to send a chunk to the server
def send_chunk(chunk):
try:
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:
@@ -1,16 +1,14 @@
import logging
import sys
from pathlib import Path

import coloredlogs

from llm_server import opts
from llm_server import globals


class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
self._format = globals.LOGGING_FORMAT

@property
def level(self):

@@ -30,10 +28,9 @@ class LoggingInfo:


logging_info = LoggingInfo()
LOG_DIRECTORY = None


def init_logging(filepath: Path = None):
def init_logging():
"""
Set up the parent logger. Ensures this logger and all of its children log to a file.
This is only called by `server.py` since there is weirdness with Gunicorn. The daemon doesn't need this.

@@ -42,17 +39,6 @@ def init_logging(filepath: Path = None):
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)

if filepath:
p = Path(filepath)
if not p.parent.is_dir():
logger.fatal(f'Log directory does not exist: {p.parent}')
sys.exit(1)
LOG_DIRECTORY = p.parent
handler = logging.FileHandler(filepath)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)

@@ -64,7 +50,4 @@ def create_logger(name):
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
if LOG_DIRECTORY:
handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log')
logger.addHandler(handler)
return logger
@@ -1 +0,0 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
@@ -1,45 +0,0 @@
# Read-only global variables

# Uppercase variables are read-only globals.
# Lowercase variables are ones that are set on startup and are never changed.

# TODO: rewrite the config system so I don't have to add every single config default here

frontend_api_mode = 'ooba'
max_new_tokens = 500
auth_required = False
log_prompts = False
frontend_api_client = ''
verify_ssl = True
show_num_prompts = True
show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
simultaneous_requests_per_ip = 3
manual_model_name = None
llm_middleware_name = ''
enable_openi_compatible_backend = True
openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
expose_openai_system_prompt = True
enable_streaming = True
openai_api_key = None
backend_request_timeout = 30
backend_generate_request_timeout = 95
admin_token = None
openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000

LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
@@ -3,7 +3,7 @@ from functools import wraps
import basicauth
from flask import Response, request

from llm_server import opts
from llm_server.config.global_config import GlobalConfig


def parse_token(input_token):

@@ -21,11 +21,11 @@ def parse_token(input_token):


def check_auth(token):
if not opts.admin_token:
if not GlobalConfig.get().admin_token:
# The admin token is not set/enabled.
# Default: deny all.
return False
return parse_token(token) == opts.admin_token
return parse_token(token) == GlobalConfig.get().admin_token


def authenticate():
@ -1,14 +1,14 @@
|
|||
import simplejson as json
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from typing import Union
|
||||
|
||||
import flask
|
||||
import requests
|
||||
import simplejson as json
|
||||
from flask import Request, make_response
|
||||
from flask import jsonify, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.database.database import is_valid_api_key
|
||||
from llm_server.routes.auth import parse_token
|
||||
|
||||
|
@ -34,7 +34,7 @@ def cache_control(seconds):
|
|||
# response = require_api_key()
|
||||
# ^^^^^^^^^^^^^^^^^
|
||||
# File "/srv/server/local-llm-server/llm_server/routes/helpers/http.py", line 50, in require_api_key
|
||||
# if token.startswith('SYSTEM__') or opts.auth_required:
|
||||
# if token.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
|
||||
# ^^^^^^^^^^^^^^^^
|
||||
# AttributeError: 'NoneType' object has no attribute 'startswith'
|
||||
|
||||
|
@ -50,14 +50,14 @@ def require_api_key(json_body: dict = None):
|
|||
request_json = None
|
||||
if 'X-Api-Key' in request.headers:
|
||||
api_key = request.headers['X-Api-Key']
|
||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||
if api_key.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
|
||||
if is_valid_api_key(api_key):
|
||||
return
|
||||
else:
|
||||
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||
elif 'Authorization' in request.headers:
|
||||
token = parse_token(request.headers['Authorization'])
|
||||
if (token and token.startswith('SYSTEM__')) or opts.auth_required:
|
||||
if (token and token.startswith('SYSTEM__')) or GlobalConfig.get().auth_required:
|
||||
if is_valid_api_key(token):
|
||||
return
|
||||
else:
|
||||
|
@ -65,13 +65,13 @@ def require_api_key(json_body: dict = None):
|
|||
else:
|
||||
try:
|
||||
# Handle websockets
|
||||
if opts.auth_required and not request_json:
|
||||
if GlobalConfig.get().auth_required and not request_json:
|
||||
# If we didn't get any valid JSON, deny.
|
||||
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||
|
||||
if request_json and request_json.get('X-API-KEY'):
|
||||
api_key = request_json.get('X-API-KEY')
|
||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||
if api_key.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
|
||||
if is_valid_api_key(api_key):
|
||||
return
|
||||
else:
|
||||
|
|
|
@ -3,11 +3,15 @@ from typing import Tuple
|
|||
import flask
|
||||
from flask import jsonify, request
|
||||
|
||||
from llm_server import messages, opts
|
||||
import llm_server.globals
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
_logger = create_logger('OobaRequestHandler')
|
||||
|
||||
|
||||
class OobaRequestHandler(RequestHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -16,8 +20,8 @@ class OobaRequestHandler(RequestHandler):
|
|||
def handle_request(self, return_ok: bool = True):
|
||||
assert not self.used
|
||||
if self.offline:
|
||||
print('This backend is offline:', messages.BACKEND_OFFLINE)
|
||||
return self.handle_error(messages.BACKEND_OFFLINE)
|
||||
# _logger.debug(f'This backend is offline.')
|
||||
return self.handle_error(llm_server.globals.BACKEND_OFFLINE)
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
if not request_valid:
|
||||
|
@ -36,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
return backend_response
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True):
|
||||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
msg = f'Ratelimited: you are only allowed to have {GlobalConfig.get().simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from flask import Blueprint
|
||||
|
||||
from ..request_handler import before_request
|
||||
from ..server_error import handle_server_error
|
||||
from ... import opts
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...llm.openai.oai_to_vllm import return_oai_internal_server_error
|
||||
from ...logging import create_logger
|
||||
|
||||
_logger = create_logger('OpenAI')
|
||||
|
||||
openai_bp = Blueprint('openai/v1/', __name__)
|
||||
openai_model_bp = Blueprint('openai/', __name__)
|
||||
|
@ -11,7 +14,7 @@ openai_model_bp = Blueprint('openai/', __name__)
|
|||
@openai_bp.before_request
|
||||
@openai_model_bp.before_request
|
||||
def before_oai_request():
|
||||
if not opts.enable_openi_compatible_backend:
|
||||
if not GlobalConfig.get().enable_openi_compatible_backend:
|
||||
return 'The OpenAI-compatible backend is disabled.', 401
|
||||
return before_request()
|
||||
|
||||
|
@ -24,15 +27,8 @@ def handle_error(e):
|
|||
"auth_subrequest_error"
|
||||
"""
|
||||
|
||||
print('OAI returning error:', e)
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": "Internal server error",
|
||||
"type": "auth_subrequest_error",
|
||||
"param": None,
|
||||
"code": "internal_error"
|
||||
}
|
||||
}), 500
|
||||
_logger.error(f'OAI returning error: {e}')
|
||||
return return_oai_internal_server_error()
|
||||
|
||||
|
||||
from .models import openai_list_models
|
||||
|
|
|
@ -11,10 +11,13 @@ from . import openai_bp, openai_model_bp
|
|||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from ...logging import create_logger
|
||||
|
||||
_logger = create_logger('OpenAIChatCompletions')
|
||||
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
@ -29,7 +32,7 @@ def openai_chat_completions(model_name=None):
|
|||
else:
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||
if handler.offline:
|
||||
return return_invalid_model_err(model_name)
|
||||
return return_oai_internal_server_error()
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
|
@ -38,7 +41,7 @@ def openai_chat_completions(model_name=None):
|
|||
traceback.print_exc()
|
||||
return 'Internal server error', 500
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
if not GlobalConfig.get().enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
|
@ -54,7 +57,7 @@ def openai_chat_completions(model_name=None):
|
|||
**handler.parameters
|
||||
}
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
if GlobalConfig.get().openai_silent_trim:
|
||||
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
|
||||
|
@ -92,14 +95,14 @@ def openai_chat_completions(model_name=None):
|
|||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
# Need to do this before we enter generate() since we want to be able to
|
||||
# return a 408 if necessary.
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('OAI failed to start streaming:', error_msg)
|
||||
_logger.error(f'OAI failed to start streaming: {error_msg}')
|
||||
stream_name = None # set to null so that the Finally ignores it.
|
||||
return 'Request Timeout', 408
|
||||
|
||||
|
@@ -109,9 +112,9 @@ def openai_chat_completions(model_name=None):
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
_logger.debug(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:

@@ -120,7 +123,7 @@ def openai_chat_completions(model_name=None):
data = ujson.loads(item[b'data'])
if data['error']:
# Not printing error since we can just check the daemon log.
print('OAI streaming encountered error')
_logger.warn(f'OAI streaming encountered error: {data["error"]}')
yield 'data: [DONE]\n\n'
return
elif data['new']:
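The two hunks above relay generation chunks out of a per-request Redis stream as server-sent events. A stripped-down sketch of the same `XREAD` polling loop, with the stream name, database number, and timeout as illustrative assumptions:

```python
import ujson
from redis import Redis

stream_redis = Redis(db=8)


def relay_stream(stream_name: str, timeout_ms: int = 25000):
    last_id = '0-0'
    while True:
        # Block until a new entry arrives or the timeout expires.
        entries = stream_redis.xread({stream_name: last_id}, block=timeout_ms)
        if not entries:
            # Nothing arrived in time; tell the client we're done.
            yield 'data: [DONE]\n\n'
            return
        for stream_index, item in entries[0][1]:
            last_id = stream_index
            data = ujson.loads(item[b'data'])
            if data.get('error'):
                yield 'data: [DONE]\n\n'
                return
            yield f"data: {ujson.dumps(data)}\n\n"
```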
@ -11,15 +11,18 @@ from . import openai_bp, openai_model_bp
|
|||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm import get_token_count
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||
|
||||
from ...logging import create_logger
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
||||
_logger = create_logger('OpenAICompletions')
|
||||
|
||||
|
||||
@openai_bp.route('/completions', methods=['POST'])
|
||||
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
|
||||
def openai_completions(model_name=None):
|
||||
|
@ -40,7 +43,7 @@ def openai_completions(model_name=None):
|
|||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
if GlobalConfig.get().openai_silent_trim:
|
||||
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
|
@ -66,7 +69,7 @@ def openai_completions(model_name=None):
|
|||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
|
||||
"model": running_model if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model'),
|
||||
"choices": [
|
||||
{
|
||||
"text": output,
|
||||
|
@ -88,7 +91,7 @@ def openai_completions(model_name=None):
|
|||
# response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response, 200
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
if not GlobalConfig.get().enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
|
@ -106,7 +109,7 @@ def openai_completions(model_name=None):
|
|||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
if GlobalConfig.get().openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
|
||||
if not handler.prompt:
|
||||
# Prevent issues on the backend.
|
||||
|
@ -139,12 +142,12 @@ def openai_completions(model_name=None):
|
|||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('OAI failed to start streaming:', error_msg)
|
||||
_logger.error(f'OAI failed to start streaming: {error_msg}')
|
||||
stream_name = None
|
||||
return 'Request Timeout', 408
|
||||
|
||||
|
@ -154,9 +157,9 @@ def openai_completions(model_name=None):
|
|||
try:
|
||||
last_id = '0-0'
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
|
||||
if not stream_data:
|
||||
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
|
||||
_logger.debug(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
|
@ -164,7 +167,7 @@ def openai_completions(model_name=None):
|
|||
timestamp = int(stream_index.decode('utf-8').split('-')[0])
|
||||
data = ujson.loads(item[b'data'])
|
||||
if data['error']:
|
||||
print('OAI streaming encountered error')
|
||||
_logger.error(f'OAI streaming encountered error: {data["error"]}')
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
from flask import Response
|
||||
|
||||
from . import openai_bp
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from ... import opts
|
||||
from . import openai_bp
|
||||
from ...config.global_config import GlobalConfig
|
||||
|
||||
|
||||
@openai_bp.route('/prompt', methods=['GET'])
|
||||
@flask_cache.cached(timeout=2678000, query_string=True)
|
||||
def get_openai_info():
|
||||
if opts.expose_openai_system_prompt:
|
||||
resp = Response(opts.openai_system_prompt)
|
||||
if GlobalConfig.get().expose_openai_system_prompt:
|
||||
resp = Response(GlobalConfig.get().openai_system_prompt)
|
||||
resp.headers['Content-Type'] = 'text/plain'
|
||||
return resp, 200
|
||||
else:
|
||||
|
|
|
@ -6,8 +6,8 @@ from flask import jsonify
|
|||
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
|
||||
from . import openai_bp
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...cluster.cluster_config import get_a_cluster_backend, cluster_config
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...helpers import jsonify_pretty
|
||||
from ...llm.openai.transform import generate_oai_string
|
||||
|
||||
|
@ -29,12 +29,12 @@ def openai_list_models():
|
|||
"data": oai
|
||||
}
|
||||
# TODO: verify this works
|
||||
if opts.openai_expose_our_model:
|
||||
if GlobalConfig.get().openai_expose_our_model:
|
||||
r["data"].insert(0, {
|
||||
"id": running_model,
|
||||
"object": "model",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"owned_by": opts.llm_middleware_name,
|
||||
"owned_by": GlobalConfig.get().llm_middleware_name,
|
||||
"permission": [
|
||||
{
|
||||
"id": running_model,
|
||||
|
@ -60,9 +60,9 @@ def openai_list_models():
|
|||
|
||||
@flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
|
||||
def fetch_openai_models():
|
||||
if opts.openai_api_key:
|
||||
if GlobalConfig.get().openai_api_key:
|
||||
try:
|
||||
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
|
||||
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {GlobalConfig.get().openai_api_key}"}, timeout=10)
|
||||
j = response.json()['data']
|
||||
|
||||
# The "modelperm" string appears to be user-specific, so we'll
|
||||
|
|
|
@ -6,19 +6,23 @@ from typing import Tuple
|
|||
from uuid import uuid4
|
||||
|
||||
import flask
|
||||
from flask import Response, jsonify, make_response
|
||||
from flask import Response, jsonify, make_response, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated
|
||||
from llm_server.database.database import is_api_key_moderated, is_valid_api_key
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
from llm_server.workers.moderator import add_moderation_task, get_results
|
||||
|
||||
_logger = create_logger('OpenAIRequestHandler')
|
||||
|
||||
|
||||
class OpenAIRequestHandler(RequestHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@@ -28,16 +32,18 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print('OAI Offline:', msg)
return self.handle_error(msg)
return return_oai_internal_server_error()

if opts.openai_silent_trim:
disable_openai_handling = request.headers.get('Llm-Disable-Openai', False) == 'true' \
and is_valid_api_key(parse_token(request.headers.get('Authorization', ''))) \
and parse_token(request.headers.get('Authorization', '')).startswith('SYSTEM__')

if GlobalConfig.get().openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']

self.prompt = transform_messages_to_prompt(oai_messages)
self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

request_valid, invalid_response = self.validate_request()
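For context, the new `Llm-Disable-Openai` header only takes effect when the caller also presents a valid `SYSTEM__` token. A rough client-side sketch with `requests`; the URL, token, and model name are placeholders, and the exact token format depends on `parse_token`:

```python
import requests

resp = requests.post(
    'https://example.com/api/openai/v1/chat/completions',  # placeholder URL
    headers={
        'Authorization': 'SYSTEM__example-token',  # must be a valid SYSTEM__ key
        'Llm-Disable-Openai': 'true',              # skip the injected system prompt
    },
    json={
        'model': 'some-model',
        'messages': [{'role': 'user', 'content': 'Hello!'}],
    },
    timeout=60,
)
print(resp.json())
```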
@ -46,7 +52,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
|
||||
if not self.prompt:
|
||||
# TODO: format this as an openai error message
|
||||
return Response('Invalid prompt'), 400
|
||||
return return_oai_invalid_request_error('Invalid prompt'), 400
|
||||
|
||||
# TODO: support Ooba backend
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
@ -55,24 +61,24 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
|
||||
if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
|
||||
try:
|
||||
# Gather the last message from the user and all preceding system messages
|
||||
msg_l = self.request.json['messages'].copy()
|
||||
msg_l.reverse()
|
||||
tag = uuid4()
|
||||
num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
|
||||
num_to_check = min(len(msg_l), GlobalConfig.get().openai_moderation_scan_last_n)
|
||||
for i in range(num_to_check):
|
||||
add_moderation_task(msg_l[i]['content'], tag)
|
||||
|
||||
flagged_categories = get_results(tag, num_to_check)
|
||||
|
||||
if len(flagged_categories):
|
||||
mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
|
||||
mod_msg = f"The user's message does not comply with {GlobalConfig.get().openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
|
||||
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
|
||||
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
|
||||
except Exception as e:
|
||||
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
|
||||
_logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
|
||||
traceback.print_exc()
|
||||
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
|
@ -106,15 +112,8 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
return response, 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
print('OAI Error:', error_msg)
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": "Invalid request, check your parameters and try again.",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}), 400
|
||||
_logger.error(f'OAI Error: {error_msg}')
|
||||
return return_oai_invalid_request_error()
|
||||
|
||||
def build_openai_response(self, prompt, response, model=None):
|
||||
# Seperate the user's prompt from the context
|
||||
|
@ -134,7 +133,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"model": running_model if GlobalConfig.get().openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
|
@ -155,7 +154,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
self.parameters, parameters_invalid_msg = self.get_parameters()
|
||||
if not self.parameters:
|
||||
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
|
||||
_logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
|
||||
return False, (Response('Invalid request, check your parameters and try again.'), 400)
|
||||
invalid_oai_err_msg = validate_oai(self.parameters)
|
||||
if invalid_oai_err_msg:
|
||||
|
|
|
@ -6,10 +6,11 @@ from uuid import uuid4
|
|||
import ujson as json
|
||||
from redis import Redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.database.database import get_token_ratelimit
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
|
||||
def increment_ip_count(client_ip: str, redis_key):
|
||||
|
@ -30,6 +31,7 @@ class RedisPriorityQueue:
|
|||
def __init__(self, name, db: int = 12):
|
||||
self.name = name
|
||||
self.redis = RedisCustom(name, db=db)
|
||||
self._logger = create_logger('RedisPriorityQueue')
|
||||
|
||||
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
|
||||
# TODO: remove this when we're sure nothing strange is happening
|
||||
|
@ -41,7 +43,7 @@ class RedisPriorityQueue:
|
|||
ip_count = self.get_ip_request_count(item[1])
|
||||
_, simultaneous_ip = get_token_ratelimit(item[2])
|
||||
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
|
||||
print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
|
||||
self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} request queued.')
|
||||
return None # reject the request
|
||||
|
||||
timestamp = time.time()
|
||||
|
@ -93,12 +95,12 @@ class RedisPriorityQueue:
|
|||
for item in self.items():
|
||||
item_data = json.loads(item)
|
||||
timestamp = item_data[-2]
|
||||
if now - timestamp > opts.backend_generate_request_timeout:
|
||||
if now - timestamp > GlobalConfig.get().backend_generate_request_timeout:
|
||||
self.redis.zrem('queue', 0, item)
|
||||
event_id = item_data[1]
|
||||
event = DataEvent(event_id)
|
||||
event.set((False, None, 'closed'))
|
||||
print('Removed timed-out item from queue:', event_id)
|
||||
self._logger.debug('Removed timed-out item from queue: {event_id}')
|
||||
|
||||
|
||||
class DataEvent:
|
||||
|
|
|
@ -4,18 +4,21 @@ from typing import Tuple, Union
|
|||
import flask
|
||||
from flask import Response, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_token_ratelimit
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
_logger = create_logger('RequestHandler')
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
|
||||
|
@ -103,8 +106,8 @@ class RequestHandler:
|
|||
if self.parameters and not parameters_invalid_msg:
|
||||
# Backends shouldn't check max_new_tokens, but rather things specific to their backend.
|
||||
# Let the RequestHandler do the generic checks.
|
||||
if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
|
||||
invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')
|
||||
if self.parameters.get('max_new_tokens', 0) > GlobalConfig.get().max_new_tokens:
|
||||
invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {GlobalConfig.get().max_new_tokens}')
|
||||
|
||||
if prompt:
|
||||
prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
|
||||
|
@ -223,7 +226,7 @@ class RequestHandler:
|
|||
processing_ip = 0
|
||||
|
||||
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
|
||||
print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
|
||||
_logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
|
@@ -1,3 +1,8 @@
from llm_server.logging import create_logger

_logger = create_logger('handle_server_error')


def handle_server_error(e):
print('Internal Error:', e)
_logger.error(f'Internal Error: {e}')
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500
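As a usage note, this handler is meant to be registered against the Flask app; a sketch of roughly what that wiring looks like (the actual registration lives in `server.py`, and the 500 code here is an assumption):

```python
from flask import Flask

from llm_server.routes.server_error import handle_server_error

app = Flask(__name__)
# Route unhandled exceptions (HTTP 500) through the shared handler.
app.register_error_handler(500, handle_server_error)
```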
@ -1,9 +1,9 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_distinct_ips_24h, sum_column
|
||||
from llm_server.helpers import deep_sort
|
||||
|
@ -31,21 +31,21 @@ def generate_stats(regen: bool = False):
|
|||
'5_min': proompters_5_min,
|
||||
'24_hrs': get_distinct_ips_24h(),
|
||||
},
|
||||
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
|
||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
|
||||
'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
|
||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
|
||||
# 'estimated_avg_tps': estimated_avg_tps,
|
||||
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
|
||||
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
|
||||
'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
|
||||
'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
|
||||
},
|
||||
'endpoints': {
|
||||
'blocking': f'https://{base_client_api}',
|
||||
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
'streaming': f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else None,
|
||||
},
|
||||
'timestamp': int(time.time()),
|
||||
'config': {
|
||||
'gatekeeper': 'none' if opts.auth_required is False else 'token',
|
||||
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
|
||||
'api_mode': opts.frontend_api_mode
|
||||
'gatekeeper': 'none' if GlobalConfig.get().auth_required is False else 'token',
|
||||
'simultaneous_requests_per_ip': GlobalConfig.get().simultaneous_requests_per_ip,
|
||||
'api_mode': GlobalConfig.get().frontend_api_mode
|
||||
},
|
||||
'keys': {
|
||||
'openaiKeys': '∞',
|
||||
|
@ -57,12 +57,12 @@ def generate_stats(regen: bool = False):
|
|||
|
||||
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
|
||||
|
||||
if opts.show_backends:
|
||||
if GlobalConfig.get().show_backends:
|
||||
for backend_url, v in cluster_config.all().items():
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if not backend_info['online']:
|
||||
continue
|
||||
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
|
||||
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if GlobalConfig.get().show_uptime else None
|
||||
output['backends'][backend_info['hash']] = {
|
||||
'uptime': backend_uptime,
|
||||
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
|
||||
|
|
|
@ -10,15 +10,18 @@ from . import bp
|
|||
from ..helpers.http import require_api_key, validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...config.global_config import GlobalConfig
|
||||
from ...custom_redis import redis
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...logging import create_logger
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
|
||||
# We solve this by splitting the routes
|
||||
|
||||
_logger = create_logger('GenerateStream')
|
||||
|
||||
|
||||
@bp.route('/v1/stream')
|
||||
@bp.route('/<model_name>/v1/stream')
|
||||
def stream(model_name=None):
|
||||
|
@ -63,7 +66,7 @@ def do_stream(ws, model_name):
|
|||
is_error=True
|
||||
)
|
||||
|
||||
if not opts.enable_streaming:
|
||||
if not GlobalConfig.get().enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
r_headers = dict(request.headers)
|
||||
|
@ -85,7 +88,7 @@ def do_stream(ws, model_name):
|
|||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
if handler.offline:
|
||||
msg = f'{handler.selected_model} is not a valid model choice.'
|
||||
print(msg)
|
||||
_logger.debug(msg)
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': 0,
|
||||
|
@ -131,7 +134,7 @@ def do_stream(ws, model_name):
|
|||
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('Stream failed to start streaming:', error_msg)
|
||||
_logger.error(f'Stream failed to start streaming: {error_msg}')
|
||||
ws.close(reason=1014, message='Request Timeout')
|
||||
return
|
||||
|
||||
|
@ -141,16 +144,16 @@ def do_stream(ws, model_name):
|
|||
try:
|
||||
last_id = '0-0' # The ID of the last entry we read.
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
|
||||
if not stream_data:
|
||||
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
|
||||
_logger.error(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
|
||||
return
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
data = ujson.loads(item[b'data'])
|
||||
if data['error']:
|
||||
print(data['error'])
|
||||
_logger.error(f'Encountered error while streaming: {data["error"]}')
|
||||
send_err_and_quit('Encountered exception while streaming.')
|
||||
return
|
||||
elif data['new']:
|
||||
|
|
|
@ -4,9 +4,9 @@ from flask import jsonify, request
|
|||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_backends_from_model, is_valid_model
|
||||
from ...cluster.cluster_config import get_a_cluster_backend, cluster_config
|
||||
from ...config.global_config import GlobalConfig
|
||||
|
||||
|
||||
@bp.route('/v1/model', methods=['GET'])
|
||||
|
@ -31,7 +31,7 @@ def get_model(model_name=None):
|
|||
else:
|
||||
num_backends = len(get_backends_from_model(model_name))
|
||||
response = jsonify({
|
||||
'result': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'result': GlobalConfig.get().manual_model_name if GlobalConfig.get().manual_model_name else model_name,
|
||||
'model_backend_count': num_backends,
|
||||
'timestamp': int(time.time())
|
||||
}), 200
|
||||
|
|
|
@ -2,6 +2,7 @@ import time
|
|||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
|
||||
|
||||
|
||||
|
@ -10,6 +11,7 @@ from llm_server.workers.inferencer import STREAM_NAME_PREFIX
|
|||
def cleaner():
|
||||
r = Redis(db=8)
|
||||
stream_info = {}
|
||||
logger = create_logger('cleaner')
|
||||
|
||||
while True:
|
||||
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
|
||||
|
@ -26,7 +28,7 @@ def cleaner():
|
|||
# If the size hasn't changed for 5 minutes, delete the stream
|
||||
if time.time() - stream_info[stream]['time'] >= 300:
|
||||
r.delete(stream)
|
||||
print(f"Stream '{stream}' deleted due to inactivity.")
|
||||
logger.debug(f"Stream '{stream}' deleted due to inactivity.")
|
||||
del stream_info[stream]
|
||||
|
||||
time.sleep(60)
|
||||
|
|
|
@ -8,6 +8,7 @@ import ujson
|
|||
from redis import Redis
|
||||
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.logging import create_logger
|
||||
|
@ -148,12 +149,12 @@ def worker(backend_url):
|
|||
status_redis.setp(str(worker_id), None)
|
||||
|
||||
|
||||
def start_workers(cluster: dict):
|
||||
def start_workers():
|
||||
logger = create_logger('inferencer')
|
||||
i = 0
|
||||
for item in cluster:
|
||||
for _ in range(item['concurrent_gens']):
|
||||
t = threading.Thread(target=worker, args=(item['backend_url'],))
|
||||
for item in GlobalConfig.get().cluster:
|
||||
for _ in range(item.concurrent_gens):
|
||||
t = threading.Thread(target=worker, args=(item.backend_url,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
i += 1
|
||||
|
|
|
@ -4,6 +4,7 @@ import traceback
|
|||
import redis
|
||||
|
||||
from llm_server.database.database import do_db_log
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
|
||||
def db_logger():
|
||||
|
@ -16,6 +17,7 @@ def db_logger():
|
|||
r = redis.Redis(host='localhost', port=6379, db=3)
|
||||
p = r.pubsub()
|
||||
p.subscribe('database-logger')
|
||||
logger = create_logger('main_bg')
|
||||
|
||||
for message in p.listen():
|
||||
try:
|
||||
|
@ -28,4 +30,4 @@ def db_logger():
|
|||
if function_name == 'log_prompt':
|
||||
do_db_log(*args, **kwargs)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
|
@ -2,15 +2,17 @@ import time
|
|||
|
||||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import get_backends, cluster_config
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_info
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
logger = create_logger('main_bg')
|
||||
while True:
|
||||
online, offline = get_backends()
|
||||
for backend_url in online:
|
||||
|
@ -29,12 +31,12 @@ def main_background_thread():
|
|||
if average_generation_elapsed_sec and average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
|
||||
|
||||
if opts.background_homepage_cacher:
|
||||
if GlobalConfig.get().background_homepage_cacher:
|
||||
try:
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
r = requests.get('https://' + base_client_api, timeout=5)
|
||||
except Exception as e:
|
||||
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
|
||||
logger.error(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
|
||||
|
||||
backends = priority_queue.get_backends()
|
||||
for backend_url in backends:
|
||||
|
@ -47,11 +49,11 @@ def main_background_thread():
|
|||
def calc_stats_for_backend(backend_url, running_model, backend_mode):
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
|
||||
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
||||
average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
|
||||
|
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
|
||||
import redis as redis_redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.config.global_config import GlobalConfig
|
||||
from llm_server.llm.openai.moderation import check_moderation_endpoint
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
|
@ -29,7 +29,7 @@ def get_results(tag, num_tasks):
|
|||
num_results = 0
|
||||
start_time = time.time()
|
||||
while num_results < num_tasks:
|
||||
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
|
||||
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=GlobalConfig.get().openai_moderation_timeout)
|
||||
if result is None:
|
||||
break # Timeout occurred, break the loop.
|
||||
result_tag, categories = json.loads(result[1])
|
||||
|
@ -38,7 +38,7 @@ def get_results(tag, num_tasks):
|
|||
for item in categories:
|
||||
flagged_categories.add(item)
|
||||
num_results += 1
|
||||
if time.time() - start_time > opts.openai_moderation_timeout:
|
||||
if time.time() - start_time > GlobalConfig.get().openai_moderation_timeout:
|
||||
logger.warning('Timed out waiting for result from moderator')
|
||||
break
|
||||
return list(flagged_categories)
|
||||
|
|
|
@@ -1,8 +1,9 @@
import time
from threading import Thread

from llm_server import opts
from llm_server.cluster.worker import cluster_worker
from llm_server.config.config import cluster_worker_count
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers

@@ -21,14 +22,14 @@ def cache_stats():

def start_background():
logger = create_logger('threader')
start_workers(opts.cluster)
start_workers()

t = Thread(target=main_background_thread)
t.daemon = True
t.start()
logger.info('Started the main background thread.')

num_moderators = opts.cluster_workers * 3
num_moderators = cluster_worker_count() * 3
start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.')

@@ -45,7 +46,7 @@ def start_background():
t = Thread(target=console_printer)
t.daemon = True
t.start()
logger.info('Started the console logger.infoer.')
logger.info('Started the console logger.')

t = Thread(target=cluster_worker)
t.daemon = True
@@ -0,0 +1,54 @@
from llm_server.helpers import resolve_path

try:
import gevent.monkey

gevent.monkey.patch_all()
except ImportError:
pass

import logging
import os
import sys
import time
from pathlib import Path

from llm_server.config.global_config import GlobalConfig
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.logging import init_logging, create_logger


def post_fork(server, worker):
"""
Initialize the worker after gunicorn has forked. This is done to avoid issues with the database manager.
"""
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = script_path / '../config/config.yml'
config_path = resolve_path(config_path)

success, msg = load_config(config_path)
if not success:
logger = logging.getLogger('llm_server')
logger.setLevel(logging.INFO)
logger.error(f'Failed to load config: {msg}')
sys.exit(1)

init_logging()
logger = create_logger('Server')
logger.debug('Debug logging enabled.')

while not redis.get('daemon_started', dtype=bool):
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
time.sleep(10)

Database.initialise(**GlobalConfig.get().postgresql.dict())
create_db()

logger.info('Started HTTP worker!')
@@ -11,8 +11,10 @@ WorkingDirectory=/srv/server/local-llm-server
# Sometimes the old processes aren't terminated when the service is restarted.
ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"

# TODO: make sure gunicorn logs to stdout and logging also goes to stdout

# Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn -c other/gunicorn_conf.py --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'

Restart=always
RestartSec=2
@@ -4,7 +4,6 @@ Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gevent~=23.9.0.post1
PyMySQL~=1.1.0
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0

@@ -14,5 +13,7 @@ gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.7
gradio~=3.46.1
coloredlogs~=15.0.1
git+https://git.evulid.cc/cyberes/bison.git
pydantic
psycopg2-binary==2.9.9
85
server.py
|
@@ -1,32 +1,13 @@
import time

try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import logging
import os
import sys
from pathlib import Path

import simplejson as json
from flask import Flask, jsonify, render_template, request, Response

import config
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.config.config import MODE_UI_NAMES
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.logging import init_logging
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp

@@ -62,36 +43,6 @@ from llm_server.sock import init_wssocket
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

try:
    import vllm
except ModuleNotFoundError as e:
    print('Could not import vllm-gptq:', e)
    print('Please see README.md for install instructions.')
    sys.exit(1)

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

success, config, msg = load_config(config_path)
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

init_logging(Path(config['webserver_log_directory']) / 'server.log')
logger = logging.getLogger('llm_server')

while not redis.get('daemon_started', dtype=bool):
    logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
    time.sleep(10)

logger.info('Started HTTP worker!')

database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()

app = Flask(__name__)
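The rest of the diff replaces reads from the old opts module and the raw config dict with GlobalConfig.get().<option>. A minimal sketch of the accessor pattern this implies, assuming a pydantic model held behind a class-level singleton (field names are taken from options used elsewhere in this diff; the real class lives in llm_server/config/global_config.py and may differ):

# Assumed shape of the GlobalConfig singleton; illustrative only.
from typing import Optional
from pydantic import BaseModel

class ConfigModel(BaseModel):
    llm_middleware_name: str = ''
    frontend_api_mode: str = 'ooba'
    enable_streaming: bool = True

class GlobalConfig:
    _config: Optional[ConfigModel] = None

    @classmethod
    def initialize(cls, config: ConfigModel) -> None:
        # Presumably populated once by load_config() after the config file is parsed.
        cls._config = config

    @classmethod
    def get(cls) -> ConfigModel:
        if cls._config is None:
            raise RuntimeError('GlobalConfig accessed before the config was loaded.')
        return cls._config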
@@ -142,13 +93,13 @@ def home():
        # to None by the daemon.
        default_model_info['context_size'] = '-'

    if len(config['analytics_tracking_code']):
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    if len(GlobalConfig.get().analytics_tracking_code):
        analytics_tracking_code = f"<script>\n{GlobalConfig.get().analytics_tracking_code}\n</script>"
    else:
        analytics_tracking_code = ''

    if config['info_html']:
        info_html = config['info_html']
    if GlobalConfig.get().info_html:
        info_html = GlobalConfig.get().info_html
    else:
        info_html = ''
@@ -159,25 +110,25 @@ def home():
            break

    return render_template('home.html',
                           llm_middleware_name=opts.llm_middleware_name,
                           llm_middleware_name=GlobalConfig.get().llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           default_model=default_model_info['model'],
                           default_active_gen_workers=default_model_info['processing'],
                           default_proompters_in_queue=default_model_info['queued'],
                           current_model=opts.manual_model_name if opts.manual_model_name else None,  # else running_model,
                           current_model=GlobalConfig.get().manual_model_name if GlobalConfig.get().manual_model_name else None,  # else running_model,
                           client_api=f'https://{base_client_api}',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
                           default_estimated_wait=default_estimated_wait_sec,
                           mode_name=mode_ui_names[opts.frontend_api_mode][0],
                           api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
                           streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
                           mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name,
                           api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name,
                           streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name,
                           default_context_size=default_model_info['context_size'],
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=opts.expose_openai_system_prompt,
                           enable_streaming=opts.enable_streaming,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if GlobalConfig.get().enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=GlobalConfig.get().expose_openai_system_prompt,
                           enable_streaming=GlobalConfig.get().enable_streaming,
                           model_choices=model_choices,
                           proompters_5_min=stats['stats']['proompters']['5_min'],
                           proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
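The render_template() change also swaps the old positional tuple lookups (mode_ui_names[...][0..2]) for named attributes. Given the .name, .api_name, and .streaming_name accesses above, MODE_UI_NAMES is presumably a mapping of frontend mode to a named tuple along these lines (the entry values below are invented for illustration; the real definition is in llm_server/config/config.py):

# Plausible shape of MODE_UI_NAMES implied by the attribute access above.
from typing import NamedTuple

class ModeUINames(NamedTuple):
    name: str
    api_name: str
    streaming_name: str

MODE_UI_NAMES = {
    'ooba': ModeUINames(
        name='Text Gen WebUI (ooba)',
        api_name='Blocking API url',
        streaming_name='Streaming API url',
    ),
}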
@@ -215,6 +166,6 @@ def before_app_request():


if __name__ == "__main__":
    # server_startup(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)
    print('Do not run this file directly. Instead, use gunicorn:')
    print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
    quit(1)
@@ -1,6 +1,5 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <title>{{ llm_middleware_name }}</title>
    <meta content="width=device-width, initial-scale=1" name="viewport"/>

@@ -97,8 +96,8 @@
        <p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
        <p><strong>OpenAI-Compatible API URL:</strong> {{ openai_client_api }}</p>
        {% if info_html|length > 1 %}
            <br>
            {{ info_html|safe }}
        <br>
        {{ info_html|safe }}
        {% endif %}
    </div>
@@ -112,7 +111,8 @@
        <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
        <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
        {% if enable_streaming %}
        <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
        <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
        </li>
        {% endif %}
        <li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
            API key</kbd> textbox.

@@ -124,11 +124,12 @@
    </ol>
</div>
{% if openai_client_api != 'disabled' and expose_openai_system_prompt %}
    <br>
    <div id="openai">
        <strong>OpenAI-Compatible API</strong>
        <p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
    </div>
    <br>
    <div id="openai">
        <strong>OpenAI-Compatible API</strong>
        <p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You
            can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
    </div>
{% endif %}
<br>
<div id="extra-info">{{ extra_info|safe }}</div>
@@ -147,30 +148,31 @@
    <br>

    {% for key, value in model_choices.items() %}
        <div class="info-box">
            <h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
        <div class="info-box">
            <h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}
                worker{% else %}workers{% endif %}</span></h3>

            {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
                {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
                {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
            {% else %}
                {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
            {% endif %}
            {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
                {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
                {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
            {% else %}
                {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
            {% endif %}

            <p>
                <strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
                Processing: {{ value.processing }}<br>
                Queued: {{ value.queued }}<br>
            </p>
            <p>
                <strong>Client API URL:</strong> {{ value.client_api }}<br>
                <strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
                <strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
            </p>
            <p><strong>Context Size:</strong> {{ value.context_size }}</p>
            <p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
        </div>
        <br>
            <p>
                <strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
                Processing: {{ value.processing }}<br>
                Queued: {{ value.queued }}<br>
            </p>
            <p>
                <strong>Client API URL:</strong> {{ value.client_api }}<br>
                <strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
                <strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
            </p>
            <p><strong>Context Size:</strong> {{ value.context_size }}</p>
            <p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
        </div>
        <br>
    {% endfor %}
</div>
<div class="footer">
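The loop above renders one info box per entry in model_choices, so each value needs at least the fields the template reads: backend_count, estimated_wait, concurrent_gens, processing, queued, client_api, ws_client_api, openai_client_api, context_size, and avg_generation_time. A hypothetical entry, for illustration only (the real dict is presumably built by get_model_choices() in llm_server/cluster/backend.py):

# Example of the per-model dict shape the template consumes; values invented.
model_choices = {
    'llama-2-13b-chat': {
        'backend_count': 2,
        'estimated_wait': 0,
        'concurrent_gens': 3,
        'processing': 1,
        'queued': 0,
        'client_api': 'https://example.com/api/v1',
        'ws_client_api': 'wss://example.com/api/v1/stream',
        'openai_client_api': 'https://example.com/api/openai/v1',
        'context_size': 4096,
        'avg_generation_time': 12,
    },
}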