fix openai confusion
This commit is contained in:
parent
1d1c45dc1a
commit
69b8c1e35c
|
@ -28,6 +28,9 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
|
|||
request_json_body['top_p'] = 0.01
|
||||
|
||||
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
|
||||
if request_json_body['max_tokens'] == 0:
|
||||
# We don't want to set any defaults here.
|
||||
del request_json_body['max_tokens']
|
||||
|
||||
return request_json_body
|
||||
|
||||
|
|
|
@ -40,7 +40,8 @@ class VLLMBackend(LLMBackend):
|
|||
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
|
||||
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
|
||||
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
|
||||
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty'])
|
||||
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
|
||||
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
|
||||
)
|
||||
except ValueError as e:
|
||||
return None, str(e).strip('.')
|
||||
|
|
|
@ -34,9 +34,6 @@ def openai_chat_completions(model_name=None):
|
|||
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
invalid_oai_err_msg = validate_oai(request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -38,10 +38,23 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
oai_messages = self.request.json['messages']
|
||||
|
||||
self.prompt = transform_messages_to_prompt(oai_messages)
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
if not self.prompt:
|
||||
# TODO: format this as an openai error message
|
||||
return Response('Invalid prompt'), 400
|
||||
|
||||
# TODO: support Ooba backend
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
invalid_oai_err_msg = validate_oai(self.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
|
||||
try:
|
||||
# Gather the last message from the user and all preceding system messages
|
||||
|
@ -62,13 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
|
||||
traceback.print_exc()
|
||||
|
||||
# TODO: support Ooba
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
if not self.prompt:
|
||||
# TODO: format this as an openai error message
|
||||
return Response('Invalid prompt'), 400
|
||||
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
model = self.request_json_body.get('model')
|
||||
|
@ -152,9 +158,13 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
return response
|
||||
|
||||
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
invalid_oai_err_msg = validate_oai(self.request_json_body)
|
||||
self.parameters, parameters_invalid_msg = self.get_parameters()
|
||||
if not self.parameters:
|
||||
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
|
||||
return False, (Response('Invalid request, check your parameters and try again.'), 400)
|
||||
invalid_oai_err_msg = validate_oai(self.parameters)
|
||||
if invalid_oai_err_msg:
|
||||
return False, invalid_oai_err_msg
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
# If the parameters were invalid, let the superclass deal with it.
|
||||
return super().validate_request(prompt, do_log)
|
||||
|
|
Reference in New Issue