missed a spot, clean up json error handling
This commit is contained in:
parent
8c04238e04
commit
47887c3925
|
@ -14,7 +14,8 @@ def prepare_json(json_data: dict):
|
||||||
seed = None
|
seed = None
|
||||||
typical_p = json_data.get('typical_p', None)
|
typical_p = json_data.get('typical_p', None)
|
||||||
if typical_p >= 1:
|
if typical_p >= 1:
|
||||||
typical_p = 0.999
|
# https://github.com/huggingface/text-generation-inference/issues/929
|
||||||
|
typical_p = 0.998
|
||||||
return {
|
return {
|
||||||
'inputs': json_data.get('prompt', ''),
|
'inputs': json_data.get('prompt', ''),
|
||||||
'parameters': {
|
'parameters': {
|
||||||
|
|
|
@ -45,7 +45,7 @@ class OobaRequestHandler:
|
||||||
def get_parameters(self):
|
def get_parameters(self):
|
||||||
request_valid_json, self.request_json_body = validate_json(self.request.data)
|
request_valid_json, self.request_json_body = validate_json(self.request.data)
|
||||||
if not request_valid_json:
|
if not request_valid_json:
|
||||||
raise InvalidJSONError
|
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||||
parameters = self.request_json_body.copy()
|
parameters = self.request_json_body.copy()
|
||||||
del parameters['prompt']
|
del parameters['prompt']
|
||||||
return parameters
|
return parameters
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from requests.exceptions import InvalidJSONError
|
|
||||||
|
|
||||||
from . import bp
|
from . import bp
|
||||||
|
from ..helpers.http import validate_json
|
||||||
from ..request_handler import OobaRequestHandler
|
from ..request_handler import OobaRequestHandler
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/generate', methods=['POST'])
|
@bp.route('/generate', methods=['POST'])
|
||||||
def generate():
|
def generate():
|
||||||
try:
|
request_valid_json, request_json_body = validate_json(request.data)
|
||||||
|
if not request_valid_json or not request_json_body.get('prompt') or not request_json_body.get('parameters'):
|
||||||
|
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||||
|
else:
|
||||||
handler = OobaRequestHandler(request)
|
handler = OobaRequestHandler(request)
|
||||||
return handler.handle_request()
|
return handler.handle_request()
|
||||||
except InvalidJSONError:
|
|
||||||
# The request handler will throw an error if the client sent invalid JSON
|
|
||||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
|
||||||
|
|
Reference in New Issue