minor adjustments
This commit is contained in:
parent
8cbf643fd3
commit
e04d6a8a13
|
@ -5,6 +5,6 @@ log_prompts: true
|
||||||
mode: oobabooga
|
mode: oobabooga
|
||||||
auth_required: false
|
auth_required: false
|
||||||
|
|
||||||
backend_url: https://proxy.chub-archive.evulid.cc
|
backend_url: http://172.0.0.2:9104
|
||||||
|
|
||||||
database_path: ./proxy-server.db
|
database_path: ./proxy-server.db
|
|
@ -29,7 +29,7 @@ def init_db(db_path):
|
||||||
''')
|
''')
|
||||||
c.execute('''
|
c.execute('''
|
||||||
CREATE TABLE token_auth
|
CREATE TABLE token_auth
|
||||||
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER)
|
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
|
||||||
''')
|
''')
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
@ -54,11 +54,12 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
|
||||||
def is_valid_api_key(api_key):
|
def is_valid_api_key(api_key):
|
||||||
conn = sqlite3.connect(opts.database_path)
|
conn = sqlite3.connect(opts.database_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,))
|
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
token, uses, max_uses, expire = row
|
token, uses, max_uses, expire, disabled = row
|
||||||
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()):
|
disabled = bool(disabled)
|
||||||
|
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Global settings that never change after startup
|
||||||
|
|
||||||
running_model = 'none'
|
running_model = 'none'
|
||||||
concurrent_generates = 3
|
concurrent_generates = 3
|
||||||
mode = 'oobabooga'
|
mode = 'oobabooga'
|
||||||
|
|
|
@ -11,6 +11,22 @@ from llm_server import opts
|
||||||
from llm_server.database import is_valid_api_key
|
from llm_server.database import is_valid_api_key
|
||||||
|
|
||||||
|
|
||||||
|
def cache_control(seconds):
|
||||||
|
def decorator(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated_function(*args, **kwargs):
|
||||||
|
resp = make_response(f(*args, **kwargs))
|
||||||
|
if seconds >= 0:
|
||||||
|
resp.headers['Cache-Control'] = f'public, max-age={seconds}'
|
||||||
|
else:
|
||||||
|
resp.headers['Cache-Control'] = f'no-store'
|
||||||
|
return resp
|
||||||
|
|
||||||
|
return decorated_function
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def require_api_key():
|
def require_api_key():
|
||||||
if not opts.auth_required:
|
if not opts.auth_required:
|
||||||
return
|
return
|
||||||
|
@ -23,19 +39,6 @@ def require_api_key():
|
||||||
return jsonify({'code': 401, 'message': 'API key required'}), 401
|
return jsonify({'code': 401, 'message': 'API key required'}), 401
|
||||||
|
|
||||||
|
|
||||||
# def cache_control(seconds):
|
|
||||||
# def decorator(f):
|
|
||||||
# @wraps(f)
|
|
||||||
# def decorated_function(*args, **kwargs):
|
|
||||||
# resp = make_response(f(*args, **kwargs))
|
|
||||||
# resp.headers['Cache-Control'] = f'public, max-age={seconds}'
|
|
||||||
# return resp
|
|
||||||
#
|
|
||||||
# return decorated_function
|
|
||||||
#
|
|
||||||
# return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def validate_json(data: Union[str, Response]):
|
def validate_json(data: Union[str, Response]):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -2,7 +2,7 @@ from flask import jsonify, request
|
||||||
|
|
||||||
from . import bp
|
from . import bp
|
||||||
from llm_server.routes.stats import concurrent_semaphore, proompts
|
from llm_server.routes.stats import concurrent_semaphore, proompts
|
||||||
from ..helpers.http import validate_json
|
from ..helpers.http import cache_control, validate_json
|
||||||
from ... import opts
|
from ... import opts
|
||||||
from ...database import log_prompt
|
from ...database import log_prompt
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ elif opts.mode == 'hf-textgen':
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/generate', methods=['POST'])
|
@bp.route('/generate', methods=['POST'])
|
||||||
|
@cache_control(-1)
|
||||||
def generate():
|
def generate():
|
||||||
request_valid_json, request_json_body = validate_json(request.data)
|
request_valid_json, request_json_body = validate_json(request.data)
|
||||||
if not request_valid_json:
|
if not request_valid_json:
|
||||||
|
|
|
@ -3,6 +3,7 @@ import time
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
|
|
||||||
from . import bp
|
from . import bp
|
||||||
|
from ..helpers.http import cache_control
|
||||||
from ...llm.oobabooga.info import get_running_model
|
from ...llm.oobabooga.info import get_running_model
|
||||||
from ..cache import cache
|
from ..cache import cache
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ from ..cache import cache
|
||||||
|
|
||||||
@bp.route('/model', methods=['GET'])
|
@bp.route('/model', methods=['GET'])
|
||||||
@cache.cached(timeout=60, query_string=True)
|
@cache.cached(timeout=60, query_string=True)
|
||||||
|
@cache_control(60)
|
||||||
def get_model():
|
def get_model():
|
||||||
model = get_running_model()
|
model = get_running_model()
|
||||||
if not model:
|
if not model:
|
||||||
|
@ -26,7 +28,7 @@ def get_model():
|
||||||
'code': 500,
|
'code': 500,
|
||||||
'error': 'failed to reach backend'
|
'error': 'failed to reach backend'
|
||||||
}), 500
|
}), 500
|
||||||
|
else:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'result': model,
|
'result': model,
|
||||||
'timestamp': int(time.time())
|
'timestamp': int(time.time())
|
||||||
|
|
|
@ -8,10 +8,12 @@ from . import bp
|
||||||
from .. import stats
|
from .. import stats
|
||||||
from llm_server.routes.v1.generate import concurrent_semaphore
|
from llm_server.routes.v1.generate import concurrent_semaphore
|
||||||
from ..cache import cache
|
from ..cache import cache
|
||||||
|
from ..helpers.http import cache_control
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/stats', methods=['GET'])
|
@bp.route('/stats', methods=['GET'])
|
||||||
@cache.cached(timeout=60, query_string=True)
|
@cache.cached(timeout=60, query_string=True)
|
||||||
|
@cache_control(60)
|
||||||
def get_stats():
|
def get_stats():
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,
|
'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,
|
||||||
|
|
|
@ -11,6 +11,7 @@ from llm_server.database import init_db
|
||||||
from llm_server.helpers import resolve_path
|
from llm_server.helpers import resolve_path
|
||||||
from llm_server.llm.oobabooga.info import get_running_model
|
from llm_server.llm.oobabooga.info import get_running_model
|
||||||
from llm_server.routes.cache import cache
|
from llm_server.routes.cache import cache
|
||||||
|
from llm_server.routes.helpers.http import cache_control
|
||||||
from llm_server.routes.v1 import bp
|
from llm_server.routes.v1 import bp
|
||||||
|
|
||||||
config_path_environ = os.getenv("CONFIG_PATH")
|
config_path_environ = os.getenv("CONFIG_PATH")
|
||||||
|
@ -46,12 +47,14 @@ cache.init_app(app)
|
||||||
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||||
|
|
||||||
print(app.url_map)
|
|
||||||
|
# print(app.url_map)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
@app.route('/<first>')
|
@app.route('/<first>')
|
||||||
@app.route('/<first>/<path:rest>')
|
@app.route('/<first>/<path:rest>')
|
||||||
|
@cache_control(-1)
|
||||||
def fallback(first=None, rest=None):
|
def fallback(first=None, rest=None):
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'error': 404,
|
'error': 404,
|
||||||
|
|
Reference in New Issue