73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import tiktoken
|
|
from flask import Flask, current_app, jsonify
|
|
|
|
from llm_server import opts
|
|
from llm_server.config import ConfigLoader
|
|
from llm_server.database import init_db
|
|
from llm_server.helpers import resolve_path
|
|
from llm_server.llm.oobabooga.info import get_running_model
|
|
from llm_server.routes.cache import cache
|
|
from llm_server.routes.helpers.http import cache_control
|
|
from llm_server.routes.v1 import bp
|
|
|
|
# Absolute directory that contains this script (symlinks resolved); relative
# paths later in this file are anchored here.
script_path = os.path.dirname(os.path.realpath(__file__))

# A non-empty CONFIG_PATH environment variable overrides the bundled
# config/config.yml that sits next to this script.
config_path_environ = os.getenv("CONFIG_PATH")
config_path = (
    config_path_environ
    if config_path_environ
    else Path(script_path, 'config', 'config.yml')
)
|
|
|
|
# Fallback values for optional settings that are absent from the config file.
default_vars = {
    'mode': 'oobabooga',
    'log_prompts': False,
    'database_path': './proxy-server.db',
    'auth_required': False,
}
# No settings are declared mandatory to ConfigLoader; note that
# config['backend_url'] is still read unconditionally further down.
required_vars = []

config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()

# The server cannot start without a configuration: report why and bail out.
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)
|
|
|
|
# Strip leading/trailing slashes from the configured backend URL before storing it.
opts.backend_url = config['backend_url'].strip('/')

# Reject an unsupported mode before any side effects (such as database
# initialisation). Previously the error was printed but start-up carried on
# with the invalid mode; exit the same way a config-load failure does.
if config['mode'] not in ('oobabooga', 'hf-textgen'):
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# Resolve a './'-relative database path against the script's directory rather
# than the current working directory. Slice off exactly the two-character
# './' prefix: str.strip('./') would remove *every* leading and trailing '.'
# or '/' character (e.g. './.data/db.' -> 'data/db'), corrupting the name.
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'][2:])

opts.database_path = resolve_path(config['database_path'])
init_db(opts.database_path)

opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']

# Presumably asks the backend which model it is serving — confirm against
# llm_server.llm.oobabooga.info.get_running_model.
opts.running_model = get_running_model()
|
|
|
|
# Build the Flask application and bind the shared response cache to it.
app = Flask(__name__)
cache.init_app(app)

# NOTE: tokenizer setup is currently disabled. To re-enable, restore:
#   with app.app_context():
#       current_app.tokenizer = tiktoken.get_encoding("cl100k_base")

# Mount the v1 API blueprint under the /api/v1/ prefix.
app.register_blueprint(bp, url_prefix='/api/v1/')

# Debugging aid: uncomment to dump the registered routes at start-up.
# print(app.url_map)
|
|
|
|
|
|
@app.route('/')
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
@cache_control(-1)
def fallback(first=None, rest=None):
    """Answer any request not handled by another route with a JSON 404.

    ``first`` and ``rest`` only exist to receive the URL variables from the
    route patterns above; their values are ignored.
    """
    not_found_body = {'error': 404, 'msg': 'not found'}
    return jsonify(not_found_body), 404
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the built-in Flask server, listening on every network interface.
    app.run(host='0.0.0.0')
|