From 69b8c1e35c81f5c1361d81cee2de720b7c137dd6 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Wed, 11 Oct 2023 12:50:20 -0600
Subject: [PATCH] fix openai confusion

---
 llm_server/llm/openai/oai_to_vllm.py         |  3 +++
 llm_server/llm/vllm/vllm_backend.py          |  3 ++-
 llm_server/routes/openai/chat_completions.py |  3 ---
 llm_server/routes/openai_request_handler.py  | 28 +++++++++++++-------
 4 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
index cde5180..ef07a08 100644
--- a/llm_server/llm/openai/oai_to_vllm.py
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -28,6 +28,9 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
         request_json_body['top_p'] = 0.01
 
     request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
+    if request_json_body['max_tokens'] == 0:
+        # We don't want to set any defaults here.
+        del request_json_body['max_tokens']
 
     return request_json_body
 

diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 835d2ce..9665547 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -40,7 +40,8 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
                 presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
                 frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
-                length_penalty=parameters.get('length_penalty', self._default_params['length_penalty'])
+                length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
+                early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
             )
         except ValueError as e:
             return None, str(e).strip('.')

diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index d054703..afa6fd1 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -34,9 +34,6 @@ def openai_chat_completions(model_name=None):
 
     if not request_json_body.get('stream'):
         try:
-            invalid_oai_err_msg = validate_oai(request_json_body)
-            if invalid_oai_err_msg:
-                return invalid_oai_err_msg
             return handler.handle_request()
         except Exception:
             traceback.print_exc()

diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 3c2a5b1..246c3b6 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -38,10 +38,23 @@ class OpenAIRequestHandler(RequestHandler):
         oai_messages = self.request.json['messages']
 
         self.prompt = transform_messages_to_prompt(oai_messages)
+        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
             return invalid_response
 
+        if not self.prompt:
+            # TODO: format this as an openai error message
+            return Response('Invalid prompt'), 400
+
+        # TODO: support Ooba backend
+        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
+        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        if invalid_oai_err_msg:
+            return invalid_oai_err_msg
+
         if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
             try:
                 # Gather the last message from the user and all preceding system messages
@@ -62,13 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
                 print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                 traceback.print_exc()
 
-        # TODO: support Ooba
-        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
-
-        if not self.prompt:
-            # TODO: format this as an openai error message
-            return Response('Invalid prompt'), 400
-
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
         model = self.request_json_body.get('model')
@@ -152,9 +158,13 @@ class OpenAIRequestHandler(RequestHandler):
         return response
 
     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
-        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        self.parameters, parameters_invalid_msg = self.get_parameters()
+        if not self.parameters:
+            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+            return False, (Response('Invalid request, check your parameters and try again.'), 400)
+        invalid_oai_err_msg = validate_oai(self.parameters)
         if invalid_oai_err_msg:
             return False, invalid_oai_err_msg
-        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         # If the parameters were invalid, let the superclass deal with it.
         return super().validate_request(prompt, do_log)
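
For reference, a minimal standalone sketch of why the patch deletes a zero `max_tokens` in `oai_to_vllm` instead of leaving it set: if the client sends neither `max_tokens` nor `max_new_tokens`, the clamped value is 0, and keeping the key would override the backend default that `VLLMBackend` resolves later. The limit and default below are stand-in values, not the project's real configuration.

```python
# Illustrative sketch only; `opts.max_new_tokens` and the vLLM backend
# defaults are stubbed with hypothetical values.
max_new_tokens_limit = 500   # stand-in for opts.max_new_tokens
default_max_tokens = 16      # stand-in for VLLMBackend._default_params['max_tokens']

def clamp_max_tokens(request_json_body: dict) -> dict:
    # Mirrors oai_to_vllm(): take the larger of the two client fields, capped at the server limit.
    request_json_body['max_tokens'] = min(
        max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)),
        max_new_tokens_limit,
    )
    if request_json_body['max_tokens'] == 0:
        # The client sent neither field; drop the key so the backend default applies instead of 0.
        del request_json_body['max_tokens']
    return request_json_body

def resolve_backend_max_tokens(parameters: dict) -> int:
    # Mirrors the VLLMBackend lookup: fall back to the default only when the key is absent or falsy.
    return parameters.get('max_new_tokens') or parameters.get('max_tokens', default_max_tokens)

print(resolve_backend_max_tokens(clamp_max_tokens({})))                    # -> 16 (backend default)
print(resolve_backend_max_tokens(clamp_max_tokens({'max_tokens': 2048})))  # -> 500 (capped at the limit)
```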