minor adjustments

This commit is contained in:
Cyberes 2023-08-21 22:49:44 -06:00
parent 8cbf643fd3
commit e04d6a8a13
8 changed files with 39 additions and 25 deletions

View File

@ -5,6 +5,6 @@ log_prompts: true
mode: oobabooga mode: oobabooga
auth_required: false auth_required: false
backend_url: https://proxy.chub-archive.evulid.cc backend_url: http://172.0.0.2:9104
database_path: ./proxy-server.db database_path: ./proxy-server.db

View File

@ -29,7 +29,7 @@ def init_db(db_path):
''') ''')
c.execute(''' c.execute('''
CREATE TABLE token_auth CREATE TABLE token_auth
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER) (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
''') ''')
conn.commit() conn.commit()
conn.close() conn.close()
@ -54,11 +54,12 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
def is_valid_api_key(api_key): def is_valid_api_key(api_key):
conn = sqlite3.connect(opts.database_path) conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,)) cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
row = cursor.fetchone() row = cursor.fetchone()
if row is not None: if row is not None:
token, uses, max_uses, expire = row token, uses, max_uses, expire, disabled = row
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()): disabled = bool(disabled)
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True return True
return False return False

View File

@ -1,3 +1,5 @@
# Global settings that never change after startup
running_model = 'none' running_model = 'none'
concurrent_generates = 3 concurrent_generates = 3
mode = 'oobabooga' mode = 'oobabooga'

View File

@ -11,6 +11,22 @@ from llm_server import opts
from llm_server.database import is_valid_api_key from llm_server.database import is_valid_api_key
def cache_control(seconds):
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
resp = make_response(f(*args, **kwargs))
if seconds >= 0:
resp.headers['Cache-Control'] = f'public, max-age={seconds}'
else:
resp.headers['Cache-Control'] = f'no-store'
return resp
return decorated_function
return decorator
def require_api_key(): def require_api_key():
if not opts.auth_required: if not opts.auth_required:
return return
@ -23,19 +39,6 @@ def require_api_key():
return jsonify({'code': 401, 'message': 'API key required'}), 401 return jsonify({'code': 401, 'message': 'API key required'}), 401
# def cache_control(seconds):
# def decorator(f):
# @wraps(f)
# def decorated_function(*args, **kwargs):
# resp = make_response(f(*args, **kwargs))
# resp.headers['Cache-Control'] = f'public, max-age={seconds}'
# return resp
#
# return decorated_function
#
# return decorator
def validate_json(data: Union[str, Response]): def validate_json(data: Union[str, Response]):
if isinstance(data, Response): if isinstance(data, Response):
try: try:

View File

@ -2,7 +2,7 @@ from flask import jsonify, request
from . import bp from . import bp
from llm_server.routes.stats import concurrent_semaphore, proompts from llm_server.routes.stats import concurrent_semaphore, proompts
from ..helpers.http import validate_json from ..helpers.http import cache_control, validate_json
from ... import opts from ... import opts
from ...database import log_prompt from ...database import log_prompt
@ -17,6 +17,7 @@ elif opts.mode == 'hf-textgen':
@bp.route('/generate', methods=['POST']) @bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate(): def generate():
request_valid_json, request_json_body = validate_json(request.data) request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json: if not request_valid_json:

View File

@ -3,6 +3,7 @@ import time
from flask import jsonify from flask import jsonify
from . import bp from . import bp
from ..helpers.http import cache_control
from ...llm.oobabooga.info import get_running_model from ...llm.oobabooga.info import get_running_model
from ..cache import cache from ..cache import cache
@ -19,6 +20,7 @@ from ..cache import cache
@bp.route('/model', methods=['GET']) @bp.route('/model', methods=['GET'])
@cache.cached(timeout=60, query_string=True) @cache.cached(timeout=60, query_string=True)
@cache_control(60)
def get_model(): def get_model():
model = get_running_model() model = get_running_model()
if not model: if not model:
@ -26,7 +28,7 @@ def get_model():
'code': 500, 'code': 500,
'error': 'failed to reach backend' 'error': 'failed to reach backend'
}), 500 }), 500
else:
return jsonify({ return jsonify({
'result': model, 'result': model,
'timestamp': int(time.time()) 'timestamp': int(time.time())

View File

@ -8,10 +8,12 @@ from . import bp
from .. import stats from .. import stats
from llm_server.routes.v1.generate import concurrent_semaphore from llm_server.routes.v1.generate import concurrent_semaphore
from ..cache import cache from ..cache import cache
from ..helpers.http import cache_control
@bp.route('/stats', methods=['GET']) @bp.route('/stats', methods=['GET'])
@cache.cached(timeout=60, query_string=True) @cache.cached(timeout=60, query_string=True)
@cache_control(60)
def get_stats(): def get_stats():
return jsonify({ return jsonify({
'proompters_now': opts.concurrent_generates - concurrent_semaphore._value, 'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,

View File

@ -11,6 +11,7 @@ from llm_server.database import init_db
from llm_server.helpers import resolve_path from llm_server.helpers import resolve_path
from llm_server.llm.oobabooga.info import get_running_model from llm_server.llm.oobabooga.info import get_running_model
from llm_server.routes.cache import cache from llm_server.routes.cache import cache
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.v1 import bp from llm_server.routes.v1 import bp
config_path_environ = os.getenv("CONFIG_PATH") config_path_environ = os.getenv("CONFIG_PATH")
@ -46,12 +47,14 @@ cache.init_app(app)
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base") # current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/') app.register_blueprint(bp, url_prefix='/api/v1/')
print(app.url_map)
# print(app.url_map)
@app.route('/') @app.route('/')
@app.route('/<first>') @app.route('/<first>')
@app.route('/<first>/<path:rest>') @app.route('/<first>/<path:rest>')
@cache_control(-1)
def fallback(first=None, rest=None): def fallback(first=None, rest=None):
return jsonify({ return jsonify({
'error': 404, 'error': 404,