diff --git a/llm_server/llm/hf_textgen/generate.py b/llm_server/llm/hf_textgen/generate.py index c0fa4f3..ec70d54 100644 --- a/llm_server/llm/hf_textgen/generate.py +++ b/llm_server/llm/hf_textgen/generate.py @@ -14,7 +14,8 @@ def prepare_json(json_data: dict): seed = None typical_p = json_data.get('typical_p', None) if typical_p >= 1: - typical_p = 0.999 + # https://github.com/huggingface/text-generation-inference/issues/929 + typical_p = 0.998 return { 'inputs': json_data.get('prompt', ''), 'parameters': { diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index a88bdd9..de55711 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -45,7 +45,7 @@ class OobaRequestHandler: def get_parameters(self): request_valid_json, self.request_json_body = validate_json(self.request.data) if not request_valid_json: - raise InvalidJSONError + return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 parameters = self.request_json_body.copy() del parameters['prompt'] return parameters diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index cf04c1e..668856a 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -1,15 +1,15 @@ from flask import jsonify, request -from requests.exceptions import InvalidJSONError from . import bp +from ..helpers.http import validate_json from ..request_handler import OobaRequestHandler @bp.route('/generate', methods=['POST']) def generate(): - try: + request_valid_json, request_json_body = validate_json(request.data) + if not request_valid_json or not request_json_body.get('prompt') or not request_json_body.get('parameters'): + return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 + else: handler = OobaRequestHandler(request) return handler.handle_request() - except InvalidJSONError: - # The request handler will throw an error if the client sent invalid JSON - return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400