minor adjustments

This commit is contained in:
Cyberes 2023-08-21 22:49:44 -06:00
parent 8cbf643fd3
commit e04d6a8a13
8 changed files with 39 additions and 25 deletions

View File

@ -5,6 +5,6 @@ log_prompts: true
mode: oobabooga
auth_required: false
backend_url: https://proxy.chub-archive.evulid.cc
backend_url: http://172.0.0.2:9104
database_path: ./proxy-server.db

View File

@ -29,7 +29,7 @@ def init_db(db_path):
''')
c.execute('''
CREATE TABLE token_auth
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER)
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
''')
conn.commit()
conn.close()
@ -54,11 +54,12 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
def is_valid_api_key(api_key):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,))
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
row = cursor.fetchone()
if row is not None:
token, uses, max_uses, expire = row
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()):
token, uses, max_uses, expire, disabled = row
disabled = bool(disabled)
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True
return False

View File

@ -1,3 +1,5 @@
# Global settings that never change after startup
running_model = 'none'
concurrent_generates = 3
mode = 'oobabooga'

View File

@ -11,6 +11,22 @@ from llm_server import opts
from llm_server.database import is_valid_api_key
def cache_control(seconds):
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
resp = make_response(f(*args, **kwargs))
if seconds >= 0:
resp.headers['Cache-Control'] = f'public, max-age={seconds}'
else:
resp.headers['Cache-Control'] = f'no-store'
return resp
return decorated_function
return decorator
def require_api_key():
if not opts.auth_required:
return
@ -23,19 +39,6 @@ def require_api_key():
return jsonify({'code': 401, 'message': 'API key required'}), 401
# def cache_control(seconds):
# def decorator(f):
# @wraps(f)
# def decorated_function(*args, **kwargs):
# resp = make_response(f(*args, **kwargs))
# resp.headers['Cache-Control'] = f'public, max-age={seconds}'
# return resp
#
# return decorated_function
#
# return decorator
def validate_json(data: Union[str, Response]):
if isinstance(data, Response):
try:

View File

@ -2,7 +2,7 @@ from flask import jsonify, request
from . import bp
from llm_server.routes.stats import concurrent_semaphore, proompts
from ..helpers.http import validate_json
from ..helpers.http import cache_control, validate_json
from ... import opts
from ...database import log_prompt
@ -17,6 +17,7 @@ elif opts.mode == 'hf-textgen':
@bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate():
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json:

View File

@ -3,6 +3,7 @@ import time
from flask import jsonify
from . import bp
from ..helpers.http import cache_control
from ...llm.oobabooga.info import get_running_model
from ..cache import cache
@ -19,6 +20,7 @@ from ..cache import cache
@bp.route('/model', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
@cache_control(60)
def get_model():
model = get_running_model()
if not model:
@ -26,11 +28,11 @@ def get_model():
'code': 500,
'error': 'failed to reach backend'
}), 500
return jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
else:
return jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
# @openai_bp.route('/models', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)

View File

@ -8,10 +8,12 @@ from . import bp
from .. import stats
from llm_server.routes.v1.generate import concurrent_semaphore
from ..cache import cache
from ..helpers.http import cache_control
@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
@cache_control(60)
def get_stats():
return jsonify({
'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,

View File

@ -11,6 +11,7 @@ from llm_server.database import init_db
from llm_server.helpers import resolve_path
from llm_server.llm.oobabooga.info import get_running_model
from llm_server.routes.cache import cache
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.v1 import bp
config_path_environ = os.getenv("CONFIG_PATH")
@ -46,12 +47,14 @@ cache.init_app(app)
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')
print(app.url_map)
# print(app.url_map)
@app.route('/')
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
@cache_control(-1)
def fallback(first=None, rest=None):
return jsonify({
'error': 404,