fix API key handling
This commit is contained in:
parent
d9bbcc42e6
commit
048e5a8060
|
@ -1,4 +1,4 @@
|
||||||
import json
|
import simplejson as json
|
||||||
import traceback
|
import traceback
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
@ -29,7 +29,15 @@ def cache_control(seconds):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def require_api_key():
|
def require_api_key(json_body: dict = None):
|
||||||
|
if json_body:
|
||||||
|
request_json = json_body
|
||||||
|
elif request.headers.get('Content-Type') == 'application/json':
|
||||||
|
valid_json, request_json = validate_json(request.data)
|
||||||
|
if not valid_json:
|
||||||
|
request_json = None
|
||||||
|
else:
|
||||||
|
request_json = None
|
||||||
if 'X-Api-Key' in request.headers:
|
if 'X-Api-Key' in request.headers:
|
||||||
api_key = request.headers['X-Api-Key']
|
api_key = request.headers['X-Api-Key']
|
||||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||||
|
@ -43,12 +51,16 @@ def require_api_key():
|
||||||
if is_valid_api_key(token):
|
if is_valid_api_key(token):
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
return jsonify({'code': 403, 'message': 'Invalid token'}), 403
|
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# Handle websockets
|
# Handle websockets
|
||||||
if request.json.get('X-API-KEY'):
|
if opts.auth_required and not request_json:
|
||||||
api_key = request.json.get('X-API-KEY')
|
# If we didn't get any valid JSON, deny.
|
||||||
|
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
|
||||||
|
|
||||||
|
if request_json and request_json.get('X-API-KEY'):
|
||||||
|
api_key = request_json.get('X-API-KEY')
|
||||||
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
if api_key.startswith('SYSTEM__') or opts.auth_required:
|
||||||
if is_valid_api_key(api_key):
|
if is_valid_api_key(api_key):
|
||||||
return
|
return
|
||||||
|
@ -57,10 +69,9 @@ def require_api_key():
|
||||||
except:
|
except:
|
||||||
# TODO: remove this one we're sure this works as expected
|
# TODO: remove this one we're sure this works as expected
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
|
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict, bytes]):
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
return True, data
|
return True, data
|
||||||
try:
|
try:
|
||||||
|
@ -70,6 +81,9 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
|
||||||
elif isinstance(data, requests.models.Response):
|
elif isinstance(data, requests.models.Response):
|
||||||
data = data.json()
|
data = data.json()
|
||||||
return True, data
|
return True, data
|
||||||
|
elif isinstance(data, bytes):
|
||||||
|
s = data.decode('utf-8')
|
||||||
|
return json.loads(s)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, e
|
return False, e
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -45,10 +45,10 @@ class RequestHandler:
|
||||||
|
|
||||||
def get_auth_token(self):
|
def get_auth_token(self):
|
||||||
if self.request_json_body.get('X-API-KEY'):
|
if self.request_json_body.get('X-API-KEY'):
|
||||||
return self.request_json_body.get['X-API-KEY']
|
return self.request_json_body['X-API-KEY']
|
||||||
elif self.request.headers.get('X-Api-Key'):
|
elif self.request.headers.get('X-Api-Key'):
|
||||||
return self.request.headers['X-Api-Key']
|
return self.request.headers['X-Api-Key']
|
||||||
elif self.request.headers['Authorization']:
|
elif self.request.headers.get('Authorization'):
|
||||||
return parse_token(self.request.headers['Authorization'])
|
return parse_token(self.request.headers['Authorization'])
|
||||||
|
|
||||||
def get_client_ip(self):
|
def get_client_ip(self):
|
||||||
|
|
|
@ -21,13 +21,8 @@ from ...stream import sock
|
||||||
@sock.route('/api/v1/stream')
|
@sock.route('/api/v1/stream')
|
||||||
def stream(ws):
|
def stream(ws):
|
||||||
if not opts.enable_streaming:
|
if not opts.enable_streaming:
|
||||||
# TODO: return a formatted ST error message
|
|
||||||
return 'disabled', 401
|
return 'disabled', 401
|
||||||
|
|
||||||
auth_failure = require_api_key()
|
|
||||||
if auth_failure:
|
|
||||||
return auth_failure
|
|
||||||
|
|
||||||
r_headers = dict(request.headers)
|
r_headers = dict(request.headers)
|
||||||
r_url = request.url
|
r_url = request.url
|
||||||
|
|
||||||
|
@ -47,6 +42,10 @@ def stream(ws):
|
||||||
# TODO: implement other backends
|
# TODO: implement other backends
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
auth_failure = require_api_key(request_json_body)
|
||||||
|
if auth_failure:
|
||||||
|
return auth_failure
|
||||||
|
|
||||||
handler = OobaRequestHandler(request, request_json_body)
|
handler = OobaRequestHandler(request, request_json_body)
|
||||||
generated_text = ''
|
generated_text = ''
|
||||||
input_prompt = request_json_body['prompt']
|
input_prompt = request_json_body['prompt']
|
||||||
|
@ -156,8 +155,8 @@ def stream(ws):
|
||||||
'event': 'stream_end',
|
'event': 'stream_end',
|
||||||
'message_num': message_num
|
'message_num': message_num
|
||||||
}))
|
}))
|
||||||
ws.close() # this is important if we encountered and error and exited early.
|
|
||||||
except:
|
except:
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||||
|
ws.close() # this is important if we encountered and error and exited early.
|
||||||
|
|
Reference in New Issue