fix API key handling
This commit is contained in:
parent
d9bbcc42e6
commit
048e5a8060
|
@ -1,4 +1,4 @@
|
|||
import json
|
||||
import simplejson as json
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from typing import Union
|
||||
|
@ -29,7 +29,15 @@ def cache_control(seconds):
|
|||
return decorator
|
||||
|
||||
|
||||
def require_api_key():
|
||||
def require_api_key(json_body: dict = None):
|
||||
if json_body:
|
||||
request_json = json_body
|
||||
elif request.headers.get('Content-Type') == 'application/json':
|
||||
valid_json, request_json = validate_json(request.data)
|
||||
if not valid_json:
|
||||
request_json = None
|
||||
else:
|
||||
request_json = None
|
||||
if 'X-Api-Key' in request.headers:
|
||||
api_key = request.headers['X-Api-Key']
|
||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||
|
@ -43,12 +51,16 @@ def require_api_key():
|
|||
if is_valid_api_key(token):
|
||||
return
|
||||
else:
|
||||
return jsonify({'code': 403, 'message': 'Invalid token'}), 403
|
||||
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||
else:
|
||||
try:
|
||||
# Handle websockets
|
||||
if request.json.get('X-API-KEY'):
|
||||
api_key = request.json.get('X-API-KEY')
|
||||
if opts.auth_required and not request_json:
|
||||
# If we didn't get any valid JSON, deny.
|
||||
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||
|
||||
if request_json and request_json.get('X-API-KEY'):
|
||||
api_key = request_json.get('X-API-KEY')
|
||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||
if is_valid_api_key(api_key):
|
||||
return
|
||||
|
@ -57,10 +69,9 @@ def require_api_key():
|
|||
except:
|
||||
# TODO: remove this one we're sure this works as expected
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
|
||||
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict, bytes]):
|
||||
if isinstance(data, dict):
|
||||
return True, data
|
||||
try:
|
||||
|
@ -70,6 +81,9 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
|
|||
elif isinstance(data, requests.models.Response):
|
||||
data = data.json()
|
||||
return True, data
|
||||
elif isinstance(data, bytes):
|
||||
s = data.decode('utf-8')
|
||||
return json.loads(s)
|
||||
except Exception as e:
|
||||
return False, e
|
||||
try:
|
||||
|
|
|
@ -45,10 +45,10 @@ class RequestHandler:
|
|||
|
||||
def get_auth_token(self):
|
||||
if self.request_json_body.get('X-API-KEY'):
|
||||
return self.request_json_body.get['X-API-KEY']
|
||||
return self.request_json_body['X-API-KEY']
|
||||
elif self.request.headers.get('X-Api-Key'):
|
||||
return self.request.headers['X-Api-Key']
|
||||
elif self.request.headers['Authorization']:
|
||||
elif self.request.headers.get('Authorization'):
|
||||
return parse_token(self.request.headers['Authorization'])
|
||||
|
||||
def get_client_ip(self):
|
||||
|
|
|
@ -21,13 +21,8 @@ from ...stream import sock
|
|||
@sock.route('/api/v1/stream')
|
||||
def stream(ws):
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a formatted ST error message
|
||||
return 'disabled', 401
|
||||
|
||||
auth_failure = require_api_key()
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
|
||||
|
@ -47,6 +42,10 @@ def stream(ws):
|
|||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
|
||||
handler = OobaRequestHandler(request, request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
|
@ -156,8 +155,8 @@ def stream(ws):
|
|||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
except:
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
|
|
Reference in New Issue