add config setting for hostname
This commit is contained in:
parent
84a1fcfdd8
commit
62412f4873
|
@ -22,6 +22,7 @@ config_default_vars = {
|
|||
'openai_api_key': None,
|
||||
'expose_openai_system_prompt': True,
|
||||
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
|
||||
'http_host': None,
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ def round_up_base(n, base):
|
|||
return math.ceil(n / base) * base
|
||||
|
||||
|
||||
def set_base_client_api(request):
|
||||
def auto_set_base_client_api(request):
|
||||
http_host = redis.get('http_host', str)
|
||||
host = request.headers.get("Host")
|
||||
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
|
||||
|
|
|
@ -6,7 +6,7 @@ from ..helpers.http import require_api_key
|
|||
from ..openai_request_handler import build_openai_response
|
||||
from ..server_error import handle_server_error
|
||||
from ... import opts
|
||||
from ...helpers import set_base_client_api
|
||||
from ...helpers import auto_set_base_client_api
|
||||
|
||||
openai_bp = Blueprint('openai/v1/', __name__)
|
||||
|
||||
|
@ -14,7 +14,7 @@ openai_bp = Blueprint('openai/v1/', __name__)
|
|||
@openai_bp.before_request
|
||||
def before_oai_request():
|
||||
# TODO: unify with normal before_request()
|
||||
set_base_client_api(request)
|
||||
auto_set_base_client_api(request)
|
||||
if not opts.enable_openi_compatible_backend:
|
||||
return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
|
||||
if request.endpoint != 'v1.get_stats':
|
||||
|
|
|
@ -7,7 +7,7 @@ from flask import Response, request
|
|||
from llm_server import opts
|
||||
from llm_server.database.conn import db_pool
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.helpers import set_base_client_api
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.cache import redis
|
||||
|
@ -194,7 +194,7 @@ def delete_dict_key(d: dict, k: Union[str, list]):
|
|||
|
||||
|
||||
def before_request():
|
||||
set_base_client_api(request)
|
||||
auto_set_base_client_api(request)
|
||||
if request.endpoint != 'v1.get_stats':
|
||||
response = require_api_key()
|
||||
if response is not None:
|
||||
|
|
16
server.py
16
server.py
|
@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
|
|||
import config
|
||||
from llm_server import opts
|
||||
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
|
||||
from llm_server.helpers import resolve_path, set_base_client_api
|
||||
from llm_server.helpers import resolve_path, auto_set_base_client_api
|
||||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.routes.cache import cache, redis
|
||||
from llm_server.routes.queue import start_workers
|
||||
|
@ -59,6 +59,10 @@ create_db()
|
|||
if config['mode'] not in ['oobabooga', 'vllm']:
|
||||
print('Unknown mode:', config['mode'])
|
||||
sys.exit(1)
|
||||
|
||||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
opts.mode = config['mode']
|
||||
opts.auth_required = config['auth_required']
|
||||
opts.log_prompts = config['log_prompts']
|
||||
|
@ -81,15 +85,17 @@ opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
|
|||
opts.enable_streaming = config['enable_streaming']
|
||||
opts.openai_api_key = config['openai_api_key']
|
||||
|
||||
if config['http_host']:
|
||||
redis.set('http_host', config['http_host'])
|
||||
redis.set('base_client_api', f'{config["http_host"]}/{opts.frontend_api_client.strip("/")}')
|
||||
print('Set host to', redis.get('http_host', str))
|
||||
|
||||
opts.verify_ssl = config['verify_ssl']
|
||||
if not opts.verify_ssl:
|
||||
import urllib3
|
||||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
if config['load_num_prompts']:
|
||||
redis.set('proompts', get_number_of_rows('prompts'))
|
||||
|
||||
|
@ -204,7 +210,7 @@ def server_error(e):
|
|||
|
||||
@app.before_request
|
||||
def before_app_request():
|
||||
set_base_client_api(request)
|
||||
auto_set_base_client_api(request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Reference in New Issue