minor adjustments
commit e04d6a8a13
parent 8cbf643fd3
@@ -5,6 +5,6 @@ log_prompts: true
 mode: oobabooga
 auth_required: false

-backend_url: https://proxy.chub-archive.evulid.cc
+backend_url: http://172.0.0.2:9104

 database_path: ./proxy-server.db

@@ -29,7 +29,7 @@ def init_db(db_path):
     ''')
     c.execute('''
        CREATE TABLE token_auth
-       (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER)
+       (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
     ''')
     conn.commit()
     conn.close()

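Note that this schema change only takes effect for databases created after the commit; an existing proxy-server.db keeps the old token_auth layout. A minimal migration sketch, assuming the stock sqlite3 module and the database_path from the config above:

    import sqlite3

    # Add the new column in place; SQLite backfills DEFAULT 0 for existing rows.
    conn = sqlite3.connect('./proxy-server.db')
    conn.execute('ALTER TABLE token_auth ADD COLUMN disabled BOOLEAN DEFAULT 0')
    conn.commit()
    conn.close()
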
@@ -54,11 +54,12 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
 def is_valid_api_key(api_key):
     conn = sqlite3.connect(opts.database_path)
     cursor = conn.cursor()
-    cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,))
+    cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
     row = cursor.fetchone()
     if row is not None:
-        token, uses, max_uses, expire = row
-        if (uses is None or uses < max_uses) and (expire is None or expire > time.time()):
+        token, uses, max_uses, expire, disabled = row
+        disabled = bool(disabled)
+        if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
             return True
     return False

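With the disabled flag in place, a token can be revoked without deleting its row or waiting for expiry. A hypothetical snippet (the token value here is made up for illustration), using the same database:

    import sqlite3

    conn = sqlite3.connect('./proxy-server.db')
    # Setting disabled = 1 makes is_valid_api_key() reject this token.
    conn.execute('UPDATE token_auth SET disabled = 1 WHERE token = ?', ('example-token',))
    conn.commit()
    conn.close()
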
@@ -1,3 +1,5 @@
+# Global settings that never change after startup
+
 running_model = 'none'
 concurrent_generates = 3
 mode = 'oobabooga'

@@ -11,6 +11,22 @@ from llm_server import opts
 from llm_server.database import is_valid_api_key


+def cache_control(seconds):
+    def decorator(f):
+        @wraps(f)
+        def decorated_function(*args, **kwargs):
+            resp = make_response(f(*args, **kwargs))
+            if seconds >= 0:
+                resp.headers['Cache-Control'] = f'public, max-age={seconds}'
+            else:
+                resp.headers['Cache-Control'] = 'no-store'
+            return resp
+
+        return decorated_function
+
+    return decorator
+
+
 def require_api_key():
     if not opts.auth_required:
         return

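The new decorator sends "public, max-age=N" for non-negative arguments and "no-store" otherwise. A quick sanity check, assuming cache_control is importable from llm_server.routes.helpers.http and using Flask's test client (the /ping route is invented for illustration):

    from flask import Flask
    from llm_server.routes.helpers.http import cache_control

    app = Flask(__name__)

    @app.route('/ping')
    @cache_control(-1)  # negative seconds -> Cache-Control: no-store
    def ping():
        return 'pong'

    with app.test_client() as client:
        print(client.get('/ping').headers['Cache-Control'])  # no-store
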
@@ -23,19 +39,6 @@ def require_api_key():
     return jsonify({'code': 401, 'message': 'API key required'}), 401


-# def cache_control(seconds):
-#     def decorator(f):
-#         @wraps(f)
-#         def decorated_function(*args, **kwargs):
-#             resp = make_response(f(*args, **kwargs))
-#             resp.headers['Cache-Control'] = f'public, max-age={seconds}'
-#             return resp
-#
-#         return decorated_function
-#
-#     return decorator
-
-
 def validate_json(data: Union[str, Response]):
     if isinstance(data, Response):
         try:

@@ -2,7 +2,7 @@ from flask import jsonify, request

 from . import bp
 from llm_server.routes.stats import concurrent_semaphore, proompts
-from ..helpers.http import validate_json
+from ..helpers.http import cache_control, validate_json
 from ... import opts
 from ...database import log_prompt

@@ -17,6 +17,7 @@ elif opts.mode == 'hf-textgen':


 @bp.route('/generate', methods=['POST'])
+@cache_control(-1)
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json:

@@ -3,6 +3,7 @@ import time
 from flask import jsonify

 from . import bp
+from ..helpers.http import cache_control
 from ...llm.oobabooga.info import get_running_model
 from ..cache import cache

@@ -19,6 +20,7 @@ from ..cache import cache

 @bp.route('/model', methods=['GET'])
 @cache.cached(timeout=60, query_string=True)
+@cache_control(60)
 def get_model():
     model = get_running_model()
     if not model:

@@ -26,7 +28,7 @@ def get_model():
             'code': 500,
             'error': 'failed to reach backend'
         }), 500
-
+    else:
         return jsonify({
             'result': model,
             'timestamp': int(time.time())

@@ -8,10 +8,12 @@ from . import bp
 from .. import stats
 from llm_server.routes.v1.generate import concurrent_semaphore
 from ..cache import cache
+from ..helpers.http import cache_control


 @bp.route('/stats', methods=['GET'])
 @cache.cached(timeout=60, query_string=True)
+@cache_control(60)
 def get_stats():
     return jsonify({
         'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,

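One caveat with the stats route: Semaphore._value is an undocumented CPython attribute, so the proompters_now computation leans on an implementation detail. A sketch of an explicit counter that exposes the same number through a public interface (all names here are illustrative, not part of the commit):

    import threading

    active_generates = 0
    _counter_lock = threading.Lock()

    def track_generate(delta):
        # Call with +1 when a generation starts and -1 when it finishes;
        # the current value is what /stats would report as proompters_now.
        global active_generates
        with _counter_lock:
            active_generates += delta
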
@@ -11,6 +11,7 @@ from llm_server.database import init_db
 from llm_server.helpers import resolve_path
 from llm_server.llm.oobabooga.info import get_running_model
 from llm_server.routes.cache import cache
+from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.v1 import bp

 config_path_environ = os.getenv("CONFIG_PATH")

@@ -46,12 +47,14 @@ cache.init_app(app)
 # current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
 app.register_blueprint(bp, url_prefix='/api/v1/')

-print(app.url_map)
+
+# print(app.url_map)


 @app.route('/')
 @app.route('/<first>')
 @app.route('/<first>/<path:rest>')
+@cache_control(-1)
 def fallback(first=None, rest=None):
     return jsonify({
         'error': 404,