redo config

This commit is contained in:
Cyberes 2024-05-07 12:20:53 -06:00
parent ff82add09e
commit ee9a0d4858
39 changed files with 363 additions and 318 deletions

View File

@ -8,9 +8,12 @@ from pathlib import Path
from redis import Redis from redis import Redis
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.load import load_config, parse_backends from llm_server.config.global_config import GlobalConfig
from llm_server.config.load import load_config
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.logging import create_logger, logging_info, init_logging from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background from llm_server.workers.threader import start_background
@ -39,19 +42,23 @@ if __name__ == "__main__":
Redis().flushall() Redis().flushall()
logger.info('Flushed Redis.') logger.info('Flushed Redis.')
success, config, msg = load_config(config_path) success, msg = load_config(config_path)
if not success: if not success:
logger.info(f'Failed to load config: {msg}') logger.info(f'Failed to load config: {msg}')
sys.exit(1) sys.exit(1)
Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
create_db() create_db()
cluster_config.clear() cluster_config.clear()
cluster_config.load(parse_backends(config)) cluster_config.load()
logger.info('Loading backend stats...') logger.info('Loading backend stats...')
generate_stats(regen=True) generate_stats(regen=True)
if GlobalConfig.get().load_num_prompts:
redis.set('proompts', get_number_of_rows('prompts'))
start_background() start_background()
# Give some time for the background threads to get themselves ready to go. # Give some time for the background threads to get themselves ready to go.

View File

@ -1,8 +1,8 @@
import numpy as np import numpy as np
from llm_server import opts
from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
from llm_server.cluster.stores import redis_running_models from llm_server.cluster.stores import redis_running_models
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.llm.generator import generator from llm_server.llm.generator import generator
from llm_server.llm.info import get_info from llm_server.llm.info import get_info
@ -108,8 +108,8 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
model_choices[model] = { model_choices[model] = {
'model': model, 'model': model,
'client_api': f'https://{base_client_api}/{model}', 'client_api': f'https://{base_client_api}/{model}',
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None, 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if GlobalConfig.get().enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled', 'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if GlobalConfig.get().enable_openi_compatible_backend else 'disabled',
'backend_count': len(b), 'backend_count': len(b),
'estimated_wait': estimated_wait_sec, 'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue, 'queued': proompters_in_queue,

View File

@ -2,9 +2,9 @@ import hashlib
import pickle import pickle
import traceback import traceback
from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models from llm_server.cluster.stores import redis_running_models
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom from llm_server.custom_redis import RedisCustom
from llm_server.logging import create_logger from llm_server.logging import create_logger
from llm_server.routes.helpers.model import estimate_model_size from llm_server.routes.helpers.model import estimate_model_size
@ -26,8 +26,13 @@ class RedisClusterStore:
def clear(self): def clear(self):
self.config_redis.flush() self.config_redis.flush()
def load(self, config: dict): def load(self):
for k, v in config.items(): stuff = {}
for item in GlobalConfig.get().cluster:
backend_url = item.backend_url.strip('/')
item.backend_url = backend_url
stuff[backend_url] = item
for k, v in stuff.items():
self.add_backend(k, v) self.add_backend(k, v)
def add_backend(self, name: str, values: dict): def add_backend(self, name: str, values: dict):
@ -92,7 +97,7 @@ def get_backends():
result[k] = {'status': status, 'priority': priority} result[k] = {'status': status, 'priority': priority}
try: try:
if not opts.prioritize_by_size: if not GlobalConfig.get().prioritize_by_size:
online_backends = sorted( online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']), ((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: -kv[1]['priority'], key=lambda kv: -kv[1]['priority'],

View File

@ -1,81 +1,14 @@
import yaml from llm_server.config.global_config import GlobalConfig
def cluster_worker_count() -> int:
    """Return the total number of concurrent generations across the cluster.

    Sums ``concurrent_gens`` over every backend entry of the globally loaded
    configuration.

    The cluster entries are model objects (the rest of this change reads them
    with attribute access, e.g. ``item.backend_url``), so they must not be
    subscripted like dicts.
    """
    return sum(backend.concurrent_gens for backend in GlobalConfig.get().cluster)
config_default_vars = {
'log_prompts': False,
'auth_required': False,
'frontend_api_client': '',
'verify_ssl': True,
'load_num_prompts': False,
'show_num_prompts': True,
'show_uptime': True,
'analytics_tracking_code': '',
'average_generation_time_mode': 'database',
'info_html': None,
'show_total_output_tokens': True,
'simultaneous_requests_per_ip': 3,
'max_new_tokens': 500,
'manual_model_name': False,
'enable_streaming': True,
'enable_openi_compatible_backend': True,
'openai_api_key': None,
'expose_openai_system_prompt': True,
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
'http_host': None,
'admin_token': None,
'openai_expose_our_model': False,
'openai_force_no_hashes': True,
'include_system_tokens_in_stats': True,
'openai_moderation_scan_last_n': 5,
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
'netdata_root': None,
'show_backends': True,
'background_homepage_cacher': True,
'openai_moderation_timeout': 5,
'prioritize_by_size': False
}
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
mode_ui_names = { mode_ui_names = {
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), 'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), 'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
} }
class ConfigLoader:
"""
default_vars = {"var1": "default1", "var2": "default2"}
required_vars = ["var1", "var3"]
config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
config = config_loader.load_config()
"""
def __init__(self, config_file, default_vars=None, required_vars=None):
self.config_file = config_file
self.default_vars = default_vars if default_vars else {}
self.required_vars = required_vars if required_vars else []
self.config = {}
def load_config(self) -> (bool, str | None, str | None):
with open(self.config_file, 'r') as stream:
try:
self.config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
return False, None, exc
if self.config is None:
# Handle empty file
self.config = {}
# Set default variables if they are not present in the config file
for var, default_value in self.default_vars.items():
if var not in self.config:
self.config[var] = default_value
# Check if required variables are present in the config file
for var in self.required_vars:
if var not in self.config:
return False, None, f'Required variable "{var}" is missing from the config file'
return True, self.config, None

View File

@ -0,0 +1,15 @@
from llm_server.config.model import ConfigModel
class GlobalConfig:
    """Process-wide holder for the validated configuration model.

    The model is stored once on the class via :meth:`initialise` and read
    back everywhere else through :meth:`get`.
    """

    # Holds None until initialise() has been called. The annotation is a
    # string so it is never evaluated at class-creation time.
    __config_model: "ConfigModel | None" = None

    @classmethod
    def initialise(cls, config: "ConfigModel") -> None:
        """Store ``config`` as the global configuration.

        Raises:
            Exception: if a configuration has already been stored.
        """
        if cls.__config_model is not None:
            raise Exception('Config is already initialised')
        cls.__config_model = config

    # Original (misspelled) name, kept so existing callers continue to work.
    initalize = initialise

    @classmethod
    def get(cls) -> "ConfigModel | None":
        """Return the stored configuration, or None if none has been stored yet."""
        return cls.__config_model

View File

@ -1,94 +1,86 @@
import re import re
import sys import sys
from pathlib import Path
import openai import openai
from bison import bison, Option, ListOption, Scheme
import llm_server import llm_server
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars from llm_server.config.model import ConfigModel
from llm_server.config.scheme import config_scheme
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.database import get_number_of_rows
from llm_server.logging import create_logger from llm_server.logging import create_logger
from llm_server.routes.queue import PriorityQueue from llm_server.routes.queue import PriorityQueue
_logger = create_logger('config') _logger = create_logger('config')
def load_config(config_path): def validate_config(config: bison.Bison):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars) def do(v, scheme: Scheme = None):
success, config, msg = config_loader.load_config() if isinstance(v, Option) and v.choices is None:
if not success: if not isinstance(config.config[v.name], v.type):
return success, config, msg raise ValueError(f'"{v.name}" must be type {v.type}. Current value: "{config.config[v.name]}"')
elif isinstance(v, Option) and v.choices is not None:
if config.config[v.name] not in v.choices:
raise ValueError(f'"{v.name}" must be one of {v.choices}. Current value: "{config.config[v.name]}"')
elif isinstance(v, ListOption):
if isinstance(config.config[v.name], list):
for item in config.config[v.name]:
do(item, v.member_scheme)
elif isinstance(config.config[v.name], dict):
for kk, vv in config.config[v.name].items():
scheme_dict = v.member_scheme.flatten()
if not isinstance(vv, scheme_dict[kk].type):
raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
if vv not in scheme_dict[kk].choices:
raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')
elif isinstance(v, dict) and scheme is not None:
scheme_dict = scheme.flatten()
for kk, vv in v.items():
if not isinstance(vv, scheme_dict[kk].type):
raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
if vv not in scheme_dict[kk].choices:
raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')
# TODO: this is atrocious for k, v in config_scheme.flatten().items():
opts.auth_required = config['auth_required'] do(v)
opts.log_prompts = config['log_prompts']
opts.frontend_api_client = config['frontend_api_client']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.cluster = config['cluster']
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_expose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
opts.background_homepage_cacher = config['background_homepage_cacher']
opts.openai_moderation_timeout = config['openai_moderation_timeout']
opts.frontend_api_mode = config['frontend_api_mode']
opts.prioritize_by_size = config['prioritize_by_size']
# Scale the number of workers.
for item in config['cluster']:
opts.cluster_workers += item['concurrent_gens']
llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']]) def load_config(config_path: Path):
config = bison.Bison(scheme=config_scheme)
config.config_name = 'config'
config.add_config_paths(str(config_path.parent))
config.parse()
if opts.openai_expose_our_model and not opts.openai_api_key: try:
validate_config(config)
except ValueError as e:
return False, str(e)
config_model = ConfigModel(**config.config)
GlobalConfig.initalize(config_model)
if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
openai.api_key = GlobalConfig.get().openai_api_key
llm_server.routes.queue.priority_queue = PriorityQueue(set([x.backend_url for x in config_model.cluster]))
if GlobalConfig.get().openai_expose_our_model and not GlobalConfig.get().openai_api_key:
_logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.') _logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
sys.exit(1) sys.exit(1)
opts.verify_ssl = config['verify_ssl'] if not GlobalConfig.get().verify_ssl:
if not opts.verify_ssl:
import urllib3 import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
if config['http_host']: if GlobalConfig.get().http_host:
http_host = re.sub(r'https?://', '', config["http_host"]) http_host = re.sub(r'https?://', '', config["http_host"])
redis.set('http_host', http_host) redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}') redis.set('base_client_api', f'{http_host}/{GlobalConfig.get().frontend_api_client.strip("/")}')
Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database']) return True, None
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
return success, config, msg
def parse_backends(config):
if not config.get('cluster'):
return False
cluster = config.get('cluster')
config = {}
for item in cluster:
backend_url = item['backend_url'].strip('/')
item['backend_url'] = backend_url
config[backend_url] = item
return config

View File

@ -0,0 +1,74 @@
from enum import Enum
from typing import Union, List
from pydantic import BaseModel
class ConfigClusterMode(str, Enum):
    """Allowed values for the ``mode`` field of a cluster backend entry."""

    vllm = 'vllm'
class ConfigCluser(BaseModel):
    """One backend entry of the ``cluster`` list in the configuration."""

    # Base URL of the backend.
    backend_url: str
    # Number of generations this backend may run concurrently.
    concurrent_gens: int
    # Which kind of backend this is; restricted to ConfigClusterMode values.
    mode: ConfigClusterMode
    # Integer priority of this backend.
    priority: int
class ConfigFrontendApiModes(str, Enum):
    """Allowed values for the ``frontend_api_mode`` configuration option."""

    ooba = 'ooba'
class ConfigMysql(BaseModel):
    """Connection settings from the ``mysql`` section of the configuration."""

    host: str
    username: str
    password: str
    database: str
    # Size of the database connection pool.
    maxconn: int
class ConfigAvgGenTimeModes(str, Enum):
    """Allowed values for the ``average_generation_time_mode`` option."""

    database = 'database'
    minute = 'minute'
class ConfigModel(BaseModel):
    """Typed view of the whole server configuration.

    Field names mirror the option names of the configuration scheme, so an
    instance can be built directly from the parsed configuration mapping.
    Every option that other modules read through the global config must be
    declared here, otherwise the attribute does not exist on the model.
    """

    # --- Cluster / backend selection ---
    frontend_api_mode: ConfigFrontendApiModes
    cluster: List[ConfigCluser]
    prioritize_by_size: bool

    # --- Administration / infrastructure ---
    admin_token: Union[str, None]
    mysql: ConfigMysql
    http_host: str
    webserver_log_directory: str

    # --- General behaviour ---
    include_system_tokens_in_stats: bool
    background_homepage_cacher: bool
    max_new_tokens: int
    # Boolean flag (the scheme declares it with field_type=bool).
    enable_streaming: bool
    show_backends: bool
    log_prompts: bool
    verify_ssl: bool
    auth_required: bool
    simultaneous_requests_per_ip: int
    max_queued_prompts_per_ip: int
    llm_middleware_name: str
    analytics_tracking_code: Union[str, None]
    info_html: Union[str, None]

    # --- OpenAI-compatible endpoint ---
    enable_openi_compatible_backend: bool
    openai_api_key: Union[str, None]
    expose_openai_system_prompt: bool
    openai_expose_our_model: bool
    openai_force_no_hashes: bool
    # Defined by the scheme and read when building prompts from OpenAI-style
    # messages; it has to be declared here to be available on the model.
    openai_system_prompt: str
    openai_moderation_enabled: bool
    openai_moderation_timeout: int
    openai_moderation_scan_last_n: int
    openai_org_name: str
    openai_silent_trim: bool

    # --- Frontend / stats display ---
    frontend_api_client: str
    average_generation_time_mode: ConfigAvgGenTimeModes
    show_num_prompts: bool
    show_uptime: bool
    show_total_output_tokens: bool
    show_backend_info: bool
    load_num_prompts: bool
    manual_model_name: Union[str, None]

    # --- Backend requests ---
    backend_request_timeout: int
    # NOTE(review): other modules in this change also read
    # `backend_generate_request_timeout` (and one reads `backend_url`) from the
    # global config, but neither is declared here or in the scheme — confirm
    # where they should come from and what their defaults are.

View File

@ -0,0 +1,59 @@
from typing import Union
import bison
from llm_server.opts import default_openai_system_prompt
# Declarative description of every supported configuration option: its name,
# expected type, allowed choices and default value. Options declared without
# a default carry none here and are expected to come from the config file.
config_scheme = bison.Scheme(
    bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
    bison.ListOption('cluster', member_scheme=bison.Scheme(
        bison.Option('backend_url', field_type=str),
        bison.Option('concurrent_gens', field_type=int),
        bison.Option('mode', choices=['vllm'], field_type=str),
        bison.Option('priority', field_type=int),
    )),
    bison.Option('prioritize_by_size', default=True, field_type=bool),
    bison.Option('admin_token', default=None, field_type=Union[str, None]),
    # NOTE(review): `mysql` is declared as a ListOption although its members
    # (host/username/...) describe a single mapping — confirm this is intended.
    bison.ListOption('mysql', member_scheme=bison.Scheme(
        bison.Option('host', field_type=str),
        bison.Option('username', field_type=str),
        bison.Option('password', field_type=str),
        bison.Option('database', field_type=str),
        bison.Option('maxconn', field_type=int)
    )),
    bison.Option('http_host', default='', field_type=str),
    bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
    bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
    bison.Option('background_homepage_cacher', default=True, field_type=bool),
    bison.Option('max_new_tokens', default=500, field_type=int),
    bison.Option('enable_streaming', default=True, field_type=bool),
    bison.Option('show_backends', default=True, field_type=bool),
    bison.Option('log_prompts', default=True, field_type=bool),
    bison.Option('verify_ssl', default=False, field_type=bool),
    bison.Option('auth_required', default=False, field_type=bool),
    bison.Option('simultaneous_requests_per_ip', default=1, field_type=int),
    bison.Option('max_queued_prompts_per_ip', default=1, field_type=int),
    bison.Option('llm_middleware_name', default='LocalLLM', field_type=str),
    bison.Option('analytics_tracking_code', default=None, field_type=Union[str, None]),
    bison.Option('info_html', default=None, field_type=Union[str, None]),
    bison.Option('enable_openi_compatible_backend', default=True, field_type=bool),
    bison.Option('openai_api_key', default=None, field_type=Union[str, None]),
    bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
    # The default must itself be a bool: an empty-string default does not
    # satisfy field_type=bool when the option is left unset.
    bison.Option('openai_expose_our_model', default=False, field_type=bool),
    bison.Option('openai_force_no_hashes', default=True, field_type=bool),
    bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
    bison.Option('openai_moderation_enabled', default=False, field_type=bool),
    bison.Option('openai_moderation_timeout', default=5, field_type=int),
    bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
    bison.Option('openai_org_name', default='OpenAI', field_type=str),
    bison.Option('openai_silent_trim', default=True, field_type=bool),
    bison.Option('frontend_api_client', default='/api', field_type=str),
    bison.Option('average_generation_time_mode', default='database', choices=['database', 'minute'], field_type=str),
    bison.Option('show_num_prompts', default=True, field_type=bool),
    bison.Option('show_uptime', default=True, field_type=bool),
    bison.Option('show_total_output_tokens', default=True, field_type=bool),
    bison.Option('show_backend_info', default=True, field_type=bool),
    bison.Option('load_num_prompts', default=True, field_type=bool),
    bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
    bison.Option('backend_request_timeout', default=30, field_type=int)
)

View File

@ -5,7 +5,7 @@ class Database:
__connection_pool = None __connection_pool = None
@classmethod @classmethod
def initialise(cls, maxconn, **kwargs): def initialise(cls, maxconn: int, **kwargs):
if cls.__connection_pool is not None: if cls.__connection_pool is not None:
raise Exception('Database connection pool is already initialised') raise Exception('Database connection pool is already initialised')
cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn, cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,

View File

@ -3,8 +3,8 @@ import time
import traceback import traceback
from typing import Union from typing import Union
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.database.conn import CursorFromConnectionFromPool from llm_server.database.conn import CursorFromConnectionFromPool
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
@ -38,10 +38,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
if is_error: if is_error:
gen_time = None gen_time = None
if not opts.log_prompts: if not GlobalConfig.get().log_prompts:
prompt = None prompt = None
if not opts.log_prompts and not is_error: if not GlobalConfig.get().log_prompts and not is_error:
# TODO: test and verify this works as expected # TODO: test and verify this works as expected
response = None response = None
@ -75,13 +75,13 @@ def is_valid_api_key(api_key):
def is_api_key_moderated(api_key): def is_api_key_moderated(api_key):
if not api_key: if not api_key:
return opts.openai_moderation_enabled return GlobalConfig.get().openai_moderation_enabled
with CursorFromConnectionFromPool() as cursor: with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,)) cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone() row = cursor.fetchone()
if row is not None: if row is not None:
return bool(row[0]) return bool(row[0])
return opts.openai_moderation_enabled return GlobalConfig.get().openai_moderation_enabled
def get_number_of_rows(table_name): def get_number_of_rows(table_name):
@ -160,7 +160,7 @@ def increment_token_uses(token):
def get_token_ratelimit(token): def get_token_ratelimit(token):
priority = 9990 priority = 9990
simultaneous_ip = opts.simultaneous_requests_per_ip simultaneous_ip = GlobalConfig.get().simultaneous_requests_per_ip
if token: if token:
with CursorFromConnectionFromPool() as cursor: with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,)) cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))

View File

@ -7,7 +7,7 @@ from typing import Union
import simplejson as json import simplejson as json
from flask import make_response from flask import make_response
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
@ -68,4 +68,4 @@ def auto_set_base_client_api(request):
return return
else: else:
redis.set('http_host', host) redis.set('http_host', host)
redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}') redis.set('base_client_api', f'{host}/{GlobalConfig.get().frontend_api_client.strip("/")}')

View File

@ -1,19 +1,19 @@
import requests import requests
from llm_server import opts from llm_server.config.global_config import GlobalConfig
def get_running_model(backend_url: str, mode: str): def get_running_model(backend_url: str, mode: str):
if mode == 'ooba': if mode == 'ooba':
try: try:
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
r_json = backend_response.json() r_json = backend_response.json()
return r_json['result'], None return r_json['result'], None
except Exception as e: except Exception as e:
return False, e return False, e
elif mode == 'vllm': elif mode == 'vllm':
try: try:
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) backend_response = requests.get(f'{backend_url}/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
r_json = backend_response.json() r_json = backend_response.json()
return r_json['model'], None return r_json['model'], None
except Exception as e: except Exception as e:
@ -28,7 +28,7 @@ def get_info(backend_url: str, mode: str):
# raise NotImplementedError # raise NotImplementedError
elif mode == 'vllm': elif mode == 'vllm':
try: try:
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout) r = requests.get(f'{backend_url}/info', verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_request_timeout)
j = r.json() j = r.json()
except Exception as e: except Exception as e:
return {} return {}

View File

@ -5,12 +5,12 @@ import traceback
import requests import requests
from llm_server import opts from llm_server.config.global_config import GlobalConfig
def generate(json_data: dict): def generate(json_data: dict):
try: try:
r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
return False, None, 'Request to backend timed out' return False, None, 'Request to backend timed out'
except Exception as e: except Exception as e:

View File

@ -1,6 +1,6 @@
import requests import requests
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger from llm_server.logging import create_logger
_logger = create_logger('moderation') _logger = create_logger('moderation')
@ -9,7 +9,7 @@ _logger = create_logger('moderation')
def check_moderation_endpoint(prompt: str): def check_moderation_endpoint(prompt: str):
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': f"Bearer {opts.openai_api_key}", 'Authorization': f"Bearer {GlobalConfig.get().openai_api_key}",
} }
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10) response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200: if response.status_code != 200:

View File

@ -1,6 +1,6 @@
from flask import jsonify from flask import jsonify
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger from llm_server.logging import create_logger
_logger = create_logger('oai_to_vllm') _logger = create_logger('oai_to_vllm')
@ -14,7 +14,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
request_json_body['stop'] = [request_json_body['stop']] request_json_body['stop'] = [request_json_body['stop']]
if stop_hashes: if stop_hashes:
if opts.openai_force_no_hashes: if GlobalConfig.get().openai_force_no_hashes:
request_json_body['stop'].append('###') request_json_body['stop'].append('###')
else: else:
# TODO: make stopping strings a configurable # TODO: make stopping strings a configurable
@ -30,7 +30,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
if mode == 'vllm' and request_json_body.get('top_p') == 0: if mode == 'vllm' and request_json_body.get('top_p') == 0:
request_json_body['top_p'] = 0.01 request_json_body['top_p'] = 0.01
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens) request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), GlobalConfig.get().max_new_tokens)
if request_json_body['max_tokens'] == 0: if request_json_body['max_tokens'] == 0:
# We don't want to set any defaults here. # We don't want to set any defaults here.
del request_json_body['max_tokens'] del request_json_body['max_tokens']

View File

@ -7,7 +7,7 @@ from typing import Dict, List
import tiktoken import tiktoken
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
@ -85,7 +85,7 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
def transform_messages_to_prompt(oai_messages): def transform_messages_to_prompt(oai_messages):
try: try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}' prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
for msg in oai_messages: for msg in oai_messages:
if 'content' not in msg.keys() or 'role' not in msg.keys(): if 'content' not in msg.keys() or 'role' not in msg.keys():
return False return False

View File

@ -4,7 +4,7 @@ This file is used by the worker that processes requests.
import requests import requests
from llm_server import opts from llm_server.config.global_config import GlobalConfig
# TODO: make the VLMM backend return TPS and time elapsed # TODO: make the VLMM backend return TPS and time elapsed
@ -25,7 +25,7 @@ def transform_prompt_to_text(prompt: list):
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10): def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
try: try:
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out') # print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out' return False, None, 'Request to backend timed out'
@ -41,7 +41,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
def generate(json_data: dict, cluster_backend, timeout: int = None): def generate(json_data: dict, cluster_backend, timeout: int = None):
if json_data.get('stream'): if json_data.get('stream'):
try: try:
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
except Exception as e: except Exception as e:
return False return False
else: else:

View File

@ -1,7 +1,3 @@
import requests
from llm_server import opts
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p> vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong> <strong>Supported Parameters:</strong>
<ul> <ul>

View File

@ -3,8 +3,8 @@ import concurrent.futures
import requests import requests
import tiktoken import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger from llm_server.logging import create_logger
@ -32,7 +32,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
# Define a function to send a chunk to the server # Define a function to send a chunk to the server
def send_chunk(chunk): def send_chunk(chunk):
try: try:
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
j = r.json() j = r.json()
return j['length'] return j['length']
except Exception as e: except Exception as e:

View File

@ -1,45 +1,10 @@
# Read-only global variables # Read-only global variables
# Uppercase variables are read-only globals. default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
# Lowercase variables are ones that are set on startup and are never changed.
# TODO: rewrite the config system so I don't have to add every single config default here # cluster = {}
frontend_api_mode = 'ooba'
max_new_tokens = 500 REDIS_STREAM_TIMEOUT = 25000
auth_required = False
log_prompts = False
frontend_api_client = ''
verify_ssl = True
show_num_prompts = True
show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
simultaneous_requests_per_ip = 3
manual_model_name = None
llm_middleware_name = ''
enable_openi_compatible_backend = True
openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
expose_openai_system_prompt = True
enable_streaming = True
openai_api_key = None
backend_request_timeout = 30
backend_generate_request_timeout = 95
admin_token = None
openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s" LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

View File

@ -3,7 +3,7 @@ from functools import wraps
import basicauth import basicauth
from flask import Response, request from flask import Response, request
from llm_server import opts from llm_server.config.global_config import GlobalConfig
def parse_token(input_token): def parse_token(input_token):
@ -21,11 +21,11 @@ def parse_token(input_token):
def check_auth(token): def check_auth(token):
if not opts.admin_token: if not GlobalConfig.get().admin_token:
# The admin token is not set/enabled. # The admin token is not set/enabled.
# Default: deny all. # Default: deny all.
return False return False
return parse_token(token) == opts.admin_token return parse_token(token) == GlobalConfig.get().admin_token
def authenticate(): def authenticate():

View File

@ -1,14 +1,14 @@
import simplejson as json
import traceback import traceback
from functools import wraps from functools import wraps
from typing import Union from typing import Union
import flask import flask
import requests import requests
import simplejson as json
from flask import Request, make_response from flask import Request, make_response
from flask import jsonify, request from flask import jsonify, request
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.database.database import is_valid_api_key from llm_server.database.database import is_valid_api_key
from llm_server.routes.auth import parse_token from llm_server.routes.auth import parse_token
@ -34,7 +34,7 @@ def cache_control(seconds):
# response = require_api_key() # response = require_api_key()
# ^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^
# File "/srv/server/local-llm-server/llm_server/routes/helpers/http.py", line 50, in require_api_key # File "/srv/server/local-llm-server/llm_server/routes/helpers/http.py", line 50, in require_api_key
# if token.startswith('SYSTEM__') or opts.auth_required: # if token.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
# ^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^
# AttributeError: 'NoneType' object has no attribute 'startswith' # AttributeError: 'NoneType' object has no attribute 'startswith'
@ -50,14 +50,14 @@ def require_api_key(json_body: dict = None):
request_json = None request_json = None
if 'X-Api-Key' in request.headers: if 'X-Api-Key' in request.headers:
api_key = request.headers['X-Api-Key'] api_key = request.headers['X-Api-Key']
if api_key.startswith('SYSTEM__') or opts.auth_required: if api_key.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
if is_valid_api_key(api_key): if is_valid_api_key(api_key):
return return
else: else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403 return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
elif 'Authorization' in request.headers: elif 'Authorization' in request.headers:
token = parse_token(request.headers['Authorization']) token = parse_token(request.headers['Authorization'])
if (token and token.startswith('SYSTEM__')) or opts.auth_required: if (token and token.startswith('SYSTEM__')) or GlobalConfig.get().auth_required:
if is_valid_api_key(token): if is_valid_api_key(token):
return return
else: else:
@ -65,13 +65,13 @@ def require_api_key(json_body: dict = None):
else: else:
try: try:
# Handle websockets # Handle websockets
if opts.auth_required and not request_json: if GlobalConfig.get().auth_required and not request_json:
# If we didn't get any valid JSON, deny. # If we didn't get any valid JSON, deny.
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403 return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
if request_json and request_json.get('X-API-KEY'): if request_json and request_json.get('X-API-KEY'):
api_key = request_json.get('X-API-KEY') api_key = request_json.get('X-API-KEY')
if api_key.startswith('SYSTEM__') or opts.auth_required: if api_key.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
if is_valid_api_key(api_key): if is_valid_api_key(api_key):
return return
else: else:

View File

@ -3,7 +3,8 @@ from typing import Tuple
import flask import flask
from flask import jsonify, request from flask import jsonify, request
from llm_server import messages, opts from llm_server import messages
from llm_server.config.global_config import GlobalConfig
from llm_server.database.log_to_db import log_to_db from llm_server.database.log_to_db import log_to_db
from llm_server.logging import create_logger from llm_server.logging import create_logger
from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.client import format_sillytavern_err
@ -39,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
return backend_response return backend_response
def handle_ratelimited(self, do_log: bool = True): def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' msg = f'Ratelimited: you are only allowed to have {GlobalConfig.get().simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg) backend_response = self.handle_error(msg)
if do_log: if do_log:
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)

View File

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
from ..request_handler import before_request from ..request_handler import before_request
from ... import opts from ...config.global_config import GlobalConfig
from ...logging import create_logger from ...logging import create_logger
_logger = create_logger('OpenAI') _logger = create_logger('OpenAI')
@ -13,7 +13,7 @@ openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request @openai_bp.before_request
@openai_model_bp.before_request @openai_model_bp.before_request
def before_oai_request(): def before_oai_request():
if not opts.enable_openi_compatible_backend: if not GlobalConfig.get().enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401 return 'The OpenAI-compatible backend is disabled.', 401
return before_request() return before_request()

View File

@ -11,7 +11,7 @@ from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ... import opts from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -41,7 +41,7 @@ def openai_chat_completions(model_name=None):
traceback.print_exc() traceback.print_exc()
return 'Internal server error', 500 return 'Internal server error', 500
else: else:
if not opts.enable_streaming: if not GlobalConfig.get().enable_streaming:
return 'Streaming disabled', 403 return 'Streaming disabled', 403
invalid_oai_err_msg = validate_oai(handler.request_json_body) invalid_oai_err_msg = validate_oai(handler.request_json_body)
@ -57,7 +57,7 @@ def openai_chat_completions(model_name=None):
**handler.parameters **handler.parameters
} }
if opts.openai_silent_trim: if GlobalConfig.get().openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)) handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else: else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages']) handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
@ -95,7 +95,7 @@ def openai_chat_completions(model_name=None):
try: try:
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30) oai_string = generate_oai_string(30)
# Need to do this before we enter generate() since we want to be able to # Need to do this before we enter generate() since we want to be able to
@ -112,9 +112,9 @@ def openai_chat_completions(model_name=None):
try: try:
last_id = '0-0' last_id = '0-0'
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
if not stream_data: if not stream_data:
_logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") _logger.debug(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:

View File

@ -11,7 +11,7 @@ from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ... import opts from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm import get_token_count from ...llm import get_token_count
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
@ -43,7 +43,7 @@ def openai_completions(model_name=None):
return invalid_oai_err_msg return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode']) handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim: if GlobalConfig.get().openai_silent_trim:
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else: else:
# The handle_request() call below will load the prompt so we don't have # The handle_request() call below will load the prompt so we don't have
@ -69,7 +69,7 @@ def openai_completions(model_name=None):
"id": f"cmpl-{generate_oai_string(30)}", "id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion", "object": "text_completion",
"created": int(time.time()), "created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'), "model": running_model if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model'),
"choices": [ "choices": [
{ {
"text": output, "text": output,
@ -91,7 +91,7 @@ def openai_completions(model_name=None):
# response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] # response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response, 200 return response, 200
else: else:
if not opts.enable_streaming: if not GlobalConfig.get().enable_streaming:
return 'Streaming disabled', 403 return 'Streaming disabled', 403
request_valid, invalid_response = handler.validate_request() request_valid, invalid_response = handler.validate_request()
@ -109,7 +109,7 @@ def openai_completions(model_name=None):
if invalid_oai_err_msg: if invalid_oai_err_msg:
return invalid_oai_err_msg return invalid_oai_err_msg
if opts.openai_silent_trim: if GlobalConfig.get().openai_silent_trim:
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']] handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
if not handler.prompt: if not handler.prompt:
# Prevent issues on the backend. # Prevent issues on the backend.
@ -142,7 +142,7 @@ def openai_completions(model_name=None):
try: try:
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30) oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait() _, stream_name, error_msg = event.wait()
@ -157,9 +157,9 @@ def openai_completions(model_name=None):
try: try:
last_id = '0-0' last_id = '0-0'
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
if not stream_data: if not stream_data:
_logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") _logger.debug(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:

View File

@ -1,15 +1,15 @@
from flask import Response from flask import Response
from . import openai_bp
from llm_server.custom_redis import flask_cache from llm_server.custom_redis import flask_cache
from ... import opts from . import openai_bp
from ...config.global_config import GlobalConfig
@openai_bp.route('/prompt', methods=['GET']) @openai_bp.route('/prompt', methods=['GET'])
@flask_cache.cached(timeout=2678000, query_string=True) @flask_cache.cached(timeout=2678000, query_string=True)
def get_openai_info(): def get_openai_info():
if opts.expose_openai_system_prompt: if GlobalConfig.get().expose_openai_system_prompt:
resp = Response(opts.openai_system_prompt) resp = Response(GlobalConfig.get().openai_system_prompt)
resp.headers['Content-Type'] = 'text/plain' resp.headers['Content-Type'] = 'text/plain'
return resp, 200 return resp, 200
else: else:

View File

@ -6,8 +6,8 @@ from flask import jsonify
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp from . import openai_bp
from ..stats import server_start_time from ..stats import server_start_time
from ... import opts
from ...cluster.cluster_config import get_a_cluster_backend, cluster_config from ...cluster.cluster_config import get_a_cluster_backend, cluster_config
from ...config.global_config import GlobalConfig
from ...helpers import jsonify_pretty from ...helpers import jsonify_pretty
from ...llm.openai.transform import generate_oai_string from ...llm.openai.transform import generate_oai_string
@ -29,12 +29,12 @@ def openai_list_models():
"data": oai "data": oai
} }
# TODO: verify this works # TODO: verify this works
if opts.openai_expose_our_model: if GlobalConfig.get().openai_expose_our_model:
r["data"].insert(0, { r["data"].insert(0, {
"id": running_model, "id": running_model,
"object": "model", "object": "model",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name, "owned_by": GlobalConfig.get().llm_middleware_name,
"permission": [ "permission": [
{ {
"id": running_model, "id": running_model,
@ -60,9 +60,9 @@ def openai_list_models():
@flask_cache.memoize(timeout=ONE_MONTH_SECONDS) @flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models(): def fetch_openai_models():
if opts.openai_api_key: if GlobalConfig.get().openai_api_key:
try: try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10) response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {GlobalConfig.get().openai_api_key}"}, timeout=10)
j = response.json()['data'] j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll # The "modelperm" string appears to be user-specific, so we'll

View File

@ -8,8 +8,8 @@ from uuid import uuid4
import flask import flask
from flask import Response, jsonify, make_response from flask import Response, jsonify, make_response
from llm_server import opts
from llm_server.cluster.backend import get_model_choices from llm_server.cluster.backend import get_model_choices
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db from llm_server.database.log_to_db import log_to_db
@ -35,7 +35,7 @@ class OpenAIRequestHandler(RequestHandler):
_logger.error(f'OAI is offline: {msg}') _logger.error(f'OAI is offline: {msg}')
return self.handle_error(msg) return self.handle_error(msg)
if opts.openai_silent_trim: if GlobalConfig.get().openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else: else:
oai_messages = self.request.json['messages'] oai_messages = self.request.json['messages']
@ -58,20 +58,20 @@ class OpenAIRequestHandler(RequestHandler):
if invalid_oai_err_msg: if invalid_oai_err_msg:
return invalid_oai_err_msg return invalid_oai_err_msg
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token): if GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token):
try: try:
# Gather the last message from the user and all preceding system messages # Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy() msg_l = self.request.json['messages'].copy()
msg_l.reverse() msg_l.reverse()
tag = uuid4() tag = uuid4()
num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n) num_to_check = min(len(msg_l), GlobalConfig.get().openai_moderation_scan_last_n)
for i in range(num_to_check): for i in range(num_to_check):
add_moderation_task(msg_l[i]['content'], tag) add_moderation_task(msg_l[i]['content'], tag)
flagged_categories = get_results(tag, num_to_check) flagged_categories = get_results(tag, num_to_check)
if len(flagged_categories): if len(flagged_categories):
mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies." mod_msg = f"The user's message does not comply with {GlobalConfig.get().openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg}) self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
self.prompt = transform_messages_to_prompt(self.request.json['messages']) self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e: except Exception as e:
@ -137,7 +137,7 @@ class OpenAIRequestHandler(RequestHandler):
"id": f"chatcmpl-{generate_oai_string(30)}", "id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model, "model": running_model if GlobalConfig.get().openai_expose_our_model else model,
"choices": [{ "choices": [{
"index": 0, "index": 0,
"message": { "message": {

View File

@ -6,8 +6,8 @@ from uuid import uuid4
import ujson as json import ujson as json
from redis import Redis from redis import Redis
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom, redis from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit from llm_server.database.database import get_token_ratelimit
from llm_server.logging import create_logger from llm_server.logging import create_logger
@ -95,7 +95,7 @@ class RedisPriorityQueue:
for item in self.items(): for item in self.items():
item_data = json.loads(item) item_data = json.loads(item)
timestamp = item_data[-2] timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout: if now - timestamp > GlobalConfig.get().backend_generate_request_timeout:
self.redis.zrem('queue', 0, item) self.redis.zrem('queue', 0, item)
event_id = item_data[1] event_id = item_data[1]
event = DataEvent(event_id) event = DataEvent(event_id)

View File

@ -4,8 +4,8 @@ from typing import Tuple, Union
import flask import flask
from flask import Response, request from flask import Response, request
from llm_server import opts
from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit from llm_server.database.database import get_token_ratelimit
from llm_server.database.log_to_db import log_to_db from llm_server.database.log_to_db import log_to_db
@ -106,8 +106,8 @@ class RequestHandler:
if self.parameters and not parameters_invalid_msg: if self.parameters and not parameters_invalid_msg:
# Backends shouldn't check max_new_tokens, but rather things specific to their backend. # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
# Let the RequestHandler do the generic checks. # Let the RequestHandler do the generic checks.
if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens: if self.parameters.get('max_new_tokens', 0) > GlobalConfig.get().max_new_tokens:
invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}') invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {GlobalConfig.get().max_new_tokens}')
if prompt: if prompt:
prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt) prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)

View File

@ -1,9 +1,9 @@
import time import time
from datetime import datetime from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_model_choices from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort from llm_server.helpers import deep_sort
@ -31,21 +31,21 @@ def generate_stats(regen: bool = False):
'5_min': proompters_5_min, '5_min': proompters_5_min,
'24_hrs': get_distinct_ips_24h(), '24_hrs': get_distinct_ips_24h(),
}, },
'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
# 'estimated_avg_tps': estimated_avg_tps, # 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, 'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None, 'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
}, },
'endpoints': { 'endpoints': {
'blocking': f'https://{base_client_api}', 'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, 'streaming': f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else None,
}, },
'timestamp': int(time.time()), 'timestamp': int(time.time()),
'config': { 'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token', 'gatekeeper': 'none' if GlobalConfig.get().auth_required is False else 'token',
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip, 'simultaneous_requests_per_ip': GlobalConfig.get().simultaneous_requests_per_ip,
'api_mode': opts.frontend_api_mode 'api_mode': GlobalConfig.get().frontend_api_mode
}, },
'keys': { 'keys': {
'openaiKeys': '', 'openaiKeys': '',
@ -57,12 +57,12 @@ def generate_stats(regen: bool = False):
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself # TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
if opts.show_backends: if GlobalConfig.get().show_backends:
for backend_url, v in cluster_config.all().items(): for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url) backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']: if not backend_info['online']:
continue continue
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if GlobalConfig.get().show_uptime else None
output['backends'][backend_info['hash']] = { output['backends'][backend_info['hash']] = {
'uptime': backend_uptime, 'uptime': backend_uptime,
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1), 'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),

View File

@ -10,7 +10,7 @@ from . import bp
from ..helpers.http import require_api_key, validate_json from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ... import opts from ...config.global_config import GlobalConfig
from ...custom_redis import redis from ...custom_redis import redis
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...logging import create_logger from ...logging import create_logger
@ -66,7 +66,7 @@ def do_stream(ws, model_name):
is_error=True is_error=True
) )
if not opts.enable_streaming: if not GlobalConfig.get().enable_streaming:
return 'Streaming disabled', 403 return 'Streaming disabled', 403
r_headers = dict(request.headers) r_headers = dict(request.headers)
@ -144,9 +144,9 @@ def do_stream(ws, model_name):
try: try:
last_id = '0-0' # The ID of the last entry we read. last_id = '0-0' # The ID of the last entry we read.
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().REDIS_STREAM_TIMEOUT)
if not stream_data: if not stream_data:
_logger.error(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") _logger.error(f"No message received in {GlobalConfig.get().REDIS_STREAM_TIMEOUT / 1000} seconds, closing stream.")
return return
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:

View File

@ -4,9 +4,9 @@ from flask import jsonify, request
from llm_server.custom_redis import flask_cache from llm_server.custom_redis import flask_cache
from . import bp from . import bp
from ... import opts
from ...cluster.backend import get_backends_from_model, is_valid_model from ...cluster.backend import get_backends_from_model, is_valid_model
from ...cluster.cluster_config import get_a_cluster_backend, cluster_config from ...cluster.cluster_config import get_a_cluster_backend, cluster_config
from ...config.global_config import GlobalConfig
@bp.route('/v1/model', methods=['GET']) @bp.route('/v1/model', methods=['GET'])
@ -31,7 +31,7 @@ def get_model(model_name=None):
else: else:
num_backends = len(get_backends_from_model(model_name)) num_backends = len(get_backends_from_model(model_name))
response = jsonify({ response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name, 'result': GlobalConfig.get().manual_model_name if GlobalConfig.get().manual_model_name else model_name,
'model_backend_count': num_backends, 'model_backend_count': num_backends,
'timestamp': int(time.time()) 'timestamp': int(time.time())
}), 200 }), 200

View File

@ -2,8 +2,8 @@ import time
import requests import requests
from llm_server import opts
from llm_server.cluster.cluster_config import get_backends, cluster_config from llm_server.cluster.cluster_config import get_backends, cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info from llm_server.llm.info import get_info
@ -31,7 +31,7 @@ def main_background_thread():
if average_generation_elapsed_sec and average_output_tokens: if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps) cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
if opts.background_homepage_cacher: if GlobalConfig.get().background_homepage_cacher:
try: try:
base_client_api = redis.get('base_client_api', dtype=str) base_client_api = redis.get('base_client_api', dtype=str)
r = requests.get('https://' + base_client_api, timeout=5) r = requests.get('https://' + base_client_api, timeout=5)
@ -51,9 +51,9 @@ def calc_stats_for_backend(backend_url, running_model, backend_mode):
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now. # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True, running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0 include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True, running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0 include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps

View File

@ -5,7 +5,7 @@ import traceback
import redis as redis_redis import redis as redis_redis
from llm_server import opts from llm_server.config.global_config import GlobalConfig
from llm_server.llm.openai.moderation import check_moderation_endpoint from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger from llm_server.logging import create_logger
@ -29,7 +29,7 @@ def get_results(tag, num_tasks):
num_results = 0 num_results = 0
start_time = time.time() start_time = time.time()
while num_results < num_tasks: while num_results < num_tasks:
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout) result = redis_moderation.blpop(['queue:flagged_categories'], timeout=GlobalConfig.get().openai_moderation_timeout)
if result is None: if result is None:
break # Timeout occurred, break the loop. break # Timeout occurred, break the loop.
result_tag, categories = json.loads(result[1]) result_tag, categories = json.loads(result[1])
@ -38,7 +38,7 @@ def get_results(tag, num_tasks):
for item in categories: for item in categories:
flagged_categories.add(item) flagged_categories.add(item)
num_results += 1 num_results += 1
if time.time() - start_time > opts.openai_moderation_timeout: if time.time() - start_time > GlobalConfig.get().openai_moderation_timeout:
logger.warning('Timed out waiting for result from moderator') logger.warning('Timed out waiting for result from moderator')
break break
return list(flagged_categories) return list(flagged_categories)

View File

@ -1,8 +1,9 @@
import time import time
from threading import Thread from threading import Thread
from llm_server import opts
from llm_server.cluster.worker import cluster_worker from llm_server.cluster.worker import cluster_worker
from llm_server.config.config import cluster_worker_count
from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers from llm_server.workers.inferencer import start_workers
@ -21,14 +22,14 @@ def cache_stats():
def start_background(): def start_background():
logger = create_logger('threader') logger = create_logger('threader')
start_workers(opts.cluster) start_workers(GlobalConfig.get().cluster)
t = Thread(target=main_background_thread) t = Thread(target=main_background_thread)
t.daemon = True t.daemon = True
t.start() t.start()
logger.info('Started the main background thread.') logger.info('Started the main background thread.')
num_moderators = opts.cluster_workers * 3 num_moderators = cluster_worker_count() * 3
start_moderation_workers(num_moderators) start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.') logger.info(f'Started {num_moderators} moderation workers.')

View File

@ -15,3 +15,5 @@ redis==5.0.1
ujson==5.8.0 ujson==5.8.0
vllm==0.2.7 vllm==0.2.7
coloredlogs~=15.0.1 coloredlogs~=15.0.1
git+https://git.evulid.cc/cyberes/bison.git
pydantic

View File

@ -1,5 +1,7 @@
import time import time
from llm_server.config.global_config import GlobalConfig
try: try:
import gevent.monkey import gevent.monkey
@ -16,13 +18,12 @@ import simplejson as json
from flask import Flask, jsonify, render_template, request, Response from flask import Flask, jsonify, render_template, request, Response
import config import config
from llm_server import opts
from llm_server.cluster.backend import get_model_choices from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database, Database from llm_server.database.conn import Database
from llm_server.database.create import create_db from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info from llm_server.llm.vllm.info import vllm_info
@ -69,30 +70,24 @@ if config_path_environ:
else: else:
config_path = Path(script_path, 'config', 'config.yml') config_path = Path(script_path, 'config', 'config.yml')
success, config, msg = load_config(config_path) success, msg = load_config(config_path)
if not success: if not success:
logger = logging.getLogger('llm_server') logger = logging.getLogger('llm_server')
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger.error(f'Failed to load config: {msg}') logger.error(f'Failed to load config: {msg}')
sys.exit(1) sys.exit(1)
init_logging(Path(config['webserver_log_directory']) / 'server.log') init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log')
logger = create_logger('Server') logger = create_logger('Server')
logger.debug('Debug logging enabled.') logger.debug('Debug logging enabled.')
try:
import vllm
except ModuleNotFoundError as e:
logger.error(f'Could not import vllm-gptq: {e}')
sys.exit(1)
while not redis.get('daemon_started', dtype=bool): while not redis.get('daemon_started', dtype=bool):
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
time.sleep(10) time.sleep(10)
logger.info('Started HTTP worker!') logger.info('Started HTTP worker!')
Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database']) Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
create_db() create_db()
app = Flask(__name__) app = Flask(__name__)
@ -161,25 +156,25 @@ def home():
break break
return render_template('home.html', return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name, llm_middleware_name=GlobalConfig.get().llm_middleware_name,
analytics_tracking_code=analytics_tracking_code, analytics_tracking_code=analytics_tracking_code,
info_html=info_html, info_html=info_html,
default_model=default_model_info['model'], default_model=default_model_info['model'],
default_active_gen_workers=default_model_info['processing'], default_active_gen_workers=default_model_info['processing'],
default_proompters_in_queue=default_model_info['queued'], default_proompters_in_queue=default_model_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model, current_model=GlobalConfig.get().manual_model_name if GlobalConfig.get().manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}', client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled', ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec, default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.frontend_api_mode][0], mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0],
api_input_textbox=mode_ui_names[opts.frontend_api_mode][1], api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1],
streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2], streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2],
default_context_size=default_model_info['context_size'], default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False), stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info, extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', openai_client_api=f'https://{base_client_api}/openai/v1' if GlobalConfig.get().enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt, expose_openai_system_prompt=GlobalConfig.get().expose_openai_system_prompt,
enable_streaming=opts.enable_streaming, enable_streaming=GlobalConfig.get().enable_streaming,
model_choices=model_choices, model_choices=model_choices,
proompters_5_min=stats['stats']['proompters']['5_min'], proompters_5_min=stats['stats']['proompters']['5_min'],
proompters_24_hrs=stats['stats']['proompters']['24_hrs'], proompters_24_hrs=stats['stats']['proompters']['24_hrs'],