add config setting for hostname

This commit is contained in:
Cyberes 2023-09-23 23:24:08 -06:00
parent 84a1fcfdd8
commit 62412f4873
5 changed files with 17 additions and 10 deletions

View File

@ -22,6 +22,7 @@ config_default_vars = {
'openai_api_key': None,
'expose_openai_system_prompt': True,
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
'http_host': None,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -68,7 +68,7 @@ def round_up_base(n, base):
return math.ceil(n / base) * base
def set_base_client_api(request):
def auto_set_base_client_api(request):
http_host = redis.get('http_host', str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):

View File

@ -6,7 +6,7 @@ from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response
from ..server_error import handle_server_error
from ... import opts
from ...helpers import set_base_client_api
from ...helpers import auto_set_base_client_api
openai_bp = Blueprint('openai/v1/', __name__)
@ -14,7 +14,7 @@ openai_bp = Blueprint('openai/v1/', __name__)
@openai_bp.before_request
def before_oai_request():
# TODO: unify with normal before_request()
set_base_client_api(request)
auto_set_base_client_api(request)
if not opts.enable_openi_compatible_backend:
return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
if request.endpoint != 'v1.get_stats':

View File

@ -7,7 +7,7 @@ from flask import Response, request
from llm_server import opts
from llm_server.database.conn import db_pool
from llm_server.database.database import log_prompt
from llm_server.helpers import set_base_client_api
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
@ -194,7 +194,7 @@ def delete_dict_key(d: dict, k: Union[str, list]):
def before_request():
set_base_client_api(request)
auto_set_base_client_api(request)
if request.endpoint != 'v1.get_stats':
response = require_api_key()
if response is not None:

View File

@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, set_base_client_api
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
@ -59,6 +59,10 @@ create_db()
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
@ -81,15 +85,17 @@ opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
if config['http_host']:
redis.set('http_host', config['http_host'])
redis.set('base_client_api', f'{config["http_host"]}/{opts.frontend_api_client.strip("/")}')
print('Set host to', redis.get('http_host', str))
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
@ -204,7 +210,7 @@ def server_error(e):
@app.before_request
def before_app_request():
set_base_client_api(request)
auto_set_base_client_api(request)
if __name__ == "__main__":