missed a spot, clean up json error handling

This commit is contained in:
Cyberes 2023-08-30 20:19:23 -06:00
parent 8c04238e04
commit 47887c3925
3 changed files with 8 additions and 7 deletions

View File

@ -14,7 +14,8 @@ def prepare_json(json_data: dict):
seed = None seed = None
typical_p = json_data.get('typical_p', None) typical_p = json_data.get('typical_p', None)
if typical_p >= 1: if typical_p >= 1:
typical_p = 0.999 # https://github.com/huggingface/text-generation-inference/issues/929
typical_p = 0.998
return { return {
'inputs': json_data.get('prompt', ''), 'inputs': json_data.get('prompt', ''),
'parameters': { 'parameters': {

View File

@ -45,7 +45,7 @@ class OobaRequestHandler:
def get_parameters(self): def get_parameters(self):
request_valid_json, self.request_json_body = validate_json(self.request.data) request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json: if not request_valid_json:
raise InvalidJSONError return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
parameters = self.request_json_body.copy() parameters = self.request_json_body.copy()
del parameters['prompt'] del parameters['prompt']
return parameters return parameters

View File

@ -1,15 +1,15 @@
from flask import jsonify, request from flask import jsonify, request
from requests.exceptions import InvalidJSONError
from . import bp from . import bp
from ..helpers.http import validate_json
from ..request_handler import OobaRequestHandler from ..request_handler import OobaRequestHandler
@bp.route('/generate', methods=['POST']) @bp.route('/generate', methods=['POST'])
def generate(): def generate():
try: request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not request_json_body.get('prompt') or not request_json_body.get('parameters'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request) handler = OobaRequestHandler(request)
return handler.handle_request() return handler.handle_request()
except InvalidJSONError:
# The request handler will throw an error if the client sent invalid JSON
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400