103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
import traceback
|
|
from functools import wraps
|
|
from typing import Union
|
|
|
|
import flask
|
|
import requests
|
|
import simplejson as json
|
|
from flask import Request, make_response
|
|
from flask import jsonify, request
|
|
|
|
from llm_server.config.global_config import GlobalConfig
|
|
from llm_server.database.database import is_valid_api_key
|
|
from llm_server.routes.auth import parse_token
|
|
|
|
|
|
def cache_control(seconds):
    """Decorator factory that sets the ``Cache-Control`` header on a view's response.

    Args:
        seconds: Max age in seconds. A positive value produces
            ``public, max-age=<seconds>``; zero or a negative value produces
            ``no-store``.

    Returns:
        A decorator that wraps a view function, converts the view's return
        value with ``make_response`` and then sets the header on it.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            resp = make_response(f(*args, **kwargs))
            if seconds > 0:
                resp.headers['Cache-Control'] = f'public, max-age={seconds}'
            else:
                # Plain literal: there is nothing to interpolate here.
                resp.headers['Cache-Control'] = 'no-store'
            return resp

        return decorated_function

    return decorator
|
|
|
|
|
|
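# TODO: confirm this crash cannot recur. parse_token() returned None for an Authorization header and require_api_key() called .startswith() on it: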
|
|
# File "/srv/server/local-llm-server/llm_server/routes/request_handler.py", line 240, in before_request
|
|
# response = require_api_key()
|
|
# ^^^^^^^^^^^^^^^^^
|
|
# File "/srv/server/local-llm-server/llm_server/routes/helpers/http.py", line 50, in require_api_key
|
|
# if token.startswith('SYSTEM__') or GlobalConfig.get().auth_required:
|
|
# ^^^^^^^^^^^^^^^^
|
|
# AttributeError: 'NoneType' object has no attribute 'startswith'
|
|
|
|
|
|
def _forbidden():
    """Return the 403 ``(response, status)`` tuple used for a rejected API key."""
    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403


def _check_api_key(api_key):
    """Decide whether ``api_key`` lets the current request through.

    Keys starting with ``SYSTEM__`` are always validated. Any other key is only
    checked when ``auth_required`` is set in the global config; in that case a
    missing, empty or non-string key is rejected without consulting the database.

    Returns:
        None to allow the request, or the 403 tuple from ``_forbidden()``.
    """
    is_str = isinstance(api_key, str)
    if is_str and api_key.startswith('SYSTEM__'):
        return None if is_valid_api_key(api_key) else _forbidden()
    if not GlobalConfig.get().auth_required:
        return None
    if is_str and api_key and is_valid_api_key(api_key):
        return None
    return _forbidden()


def _get_request_json(json_body):
    """Return the JSON payload to search for a key, or None if there is none.

    A truthy ``json_body`` wins. Otherwise the request body is parsed when its
    mimetype is ``application/json``; using the mimetype means parameters such as
    ``; charset=utf-8`` do not defeat the check.
    """
    if json_body:
        return json_body
    if request.mimetype == 'application/json':
        valid_json, parsed = validate_json(request.data)
        if valid_json:
            return parsed
    return None


def require_api_key(json_body: dict = None):
    """Authenticate the current request by API key.

    The key is looked up, in order, in:
      1. the ``X-Api-Key`` header,
      2. the ``Authorization`` header (via ``parse_token``),
      3. the ``X-API-KEY`` field of a JSON payload: ``json_body`` if truthy,
         otherwise the request body when it is ``application/json``.

    Args:
        json_body: Optional already-parsed JSON payload to search for the key.

    Returns:
        None if the request may proceed, otherwise a ``(response, 403)`` tuple.
    """
    if 'X-Api-Key' in request.headers:
        return _check_api_key(request.headers['X-Api-Key'])

    if 'Authorization' in request.headers:
        # parse_token() can yield None; _check_api_key treats that as a missing
        # key instead of calling .startswith() on it.
        return _check_api_key(parse_token(request.headers['Authorization']))

    try:
        # Handle websockets: the key is expected inside the JSON payload.
        request_json = _get_request_json(json_body)
        # A JSON array/scalar has no key; when auth is required a missing key denies.
        api_key = request_json.get('X-API-KEY') if isinstance(request_json, dict) else None
        return _check_api_key(api_key)
    except Exception:
        # TODO: remove this once we're sure this works as expected.
        # Fail closed: an unexpected error must never grant access.
        traceback.print_exc()
        return _forbidden()
|
|
|
|
|
|
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict, bytes]):
    """Try to interpret ``data`` as JSON.

    Args:
        data: An already-decoded dict (returned as-is), a Flask request/response
            (its ``.json`` attribute is used), a ``requests`` response (its
            ``.json()`` method is used), raw bytes/bytearray (decoded as UTF-8),
            or anything whose ``str()`` is a JSON document.

    Returns:
        ``(True, parsed)`` on success, or ``(False, exception)`` if parsing failed.
    """
    if isinstance(data, dict):
        return True, data
    try:
        if isinstance(data, (Request, flask.Response)):
            return True, data.json
        if isinstance(data, requests.models.Response):
            return True, data.json()
        if isinstance(data, (bytes, bytearray)):
            # Success must be reported as True: callers such as require_api_key
            # discard the payload when the flag is False.
            return True, json.loads(bytes(data).decode('utf-8'))
        return True, json.loads(str(data))
    except Exception as e:
        return False, e