fix API key handling

This commit is contained in:
Cyberes 2023-09-26 22:49:53 -06:00
parent d9bbcc42e6
commit 048e5a8060
3 changed files with 28 additions and 15 deletions

View File

@ -1,4 +1,4 @@
import json
import simplejson as json
import traceback
from functools import wraps
from typing import Union
@ -29,7 +29,15 @@ def cache_control(seconds):
return decorator
def require_api_key():
def require_api_key(json_body: dict = None):
if json_body:
request_json = json_body
elif request.headers.get('Content-Type') == 'application/json':
valid_json, request_json = validate_json(request.data)
if not valid_json:
request_json = None
else:
request_json = None
if 'X-Api-Key' in request.headers:
api_key = request.headers['X-Api-Key']
if api_key.startswith('SYSTEM__') or opts.auth_required:
@ -43,12 +51,16 @@ def require_api_key():
if is_valid_api_key(token):
return
else:
return jsonify({'code': 403, 'message': 'Invalid token'}), 403
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
else:
try:
# Handle websockets
if request.json.get('X-API-KEY'):
api_key = request.json.get('X-API-KEY')
if opts.auth_required and not request_json:
# If we didn't get any valid JSON, deny.
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
if request_json and request_json.get('X-API-KEY'):
api_key = request_json.get('X-API-KEY')
if api_key.startswith('SYSTEM__') or opts.auth_required:
if is_valid_api_key(api_key):
return
@ -57,10 +69,9 @@ def require_api_key():
except:
# TODO: remove this one we're sure this works as expected
traceback.print_exc()
return
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict, bytes]):
if isinstance(data, dict):
return True, data
try:
@ -70,6 +81,9 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
elif isinstance(data, requests.models.Response):
data = data.json()
return True, data
elif isinstance(data, bytes):
s = data.decode('utf-8')
return json.loads(s)
except Exception as e:
return False, e
try:

View File

@ -45,10 +45,10 @@ class RequestHandler:
def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'):
return self.request_json_body.get['X-API-KEY']
return self.request_json_body['X-API-KEY']
elif self.request.headers.get('X-Api-Key'):
return self.request.headers['X-Api-Key']
elif self.request.headers['Authorization']:
elif self.request.headers.get('Authorization'):
return parse_token(self.request.headers['Authorization'])
def get_client_ip(self):

View File

@ -21,13 +21,8 @@ from ...stream import sock
@sock.route('/api/v1/stream')
def stream(ws):
if not opts.enable_streaming:
# TODO: return a formatted ST error message
return 'disabled', 401
auth_failure = require_api_key()
if auth_failure:
return auth_failure
r_headers = dict(request.headers)
r_url = request.url
@ -47,6 +42,10 @@ def stream(ws):
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
handler = OobaRequestHandler(request, request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
@ -156,8 +155,8 @@ def stream(ws):
'event': 'stream_end',
'message_num': message_num
}))
ws.close() # this is important if we encountered and error and exited early.
except:
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
ws.close() # this is important if we encountered and error and exited early.