fix openai confusion

parent 1d1c45dc1a
commit 69b8c1e35c
@@ -28,6 +28,9 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
         request_json_body['top_p'] = 0.01

     request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
+    if request_json_body['max_tokens'] == 0:
+        # We don't want to set any defaults here.
+        del request_json_body['max_tokens']

     return request_json_body
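The hunk above merges the two max-token aliases (max_new_tokens from text-generation-webui-style clients, max_tokens from OpenAI-style clients), caps the result at the server limit, and now drops the key entirely when neither alias was sent so the backend keeps its own default. A minimal standalone sketch of that behaviour, with opts.max_new_tokens stubbed as a plain constant since the opts module is not part of this diff:

MAX_NEW_TOKENS = 512  # stand-in for opts.max_new_tokens

def clamp_max_tokens(body: dict, limit: int = MAX_NEW_TOKENS) -> dict:
    # Take the larger of the two client-side aliases, capped at the server limit.
    body['max_tokens'] = min(max(body.get('max_new_tokens', 0), body.get('max_tokens', 0)), limit)
    if body['max_tokens'] == 0:
        # Neither alias was sent: remove the key so the backend default applies
        # instead of asking for zero tokens.
        del body['max_tokens']
    return body

print(clamp_max_tokens({'max_new_tokens': 1024}))  # -> {'max_new_tokens': 1024, 'max_tokens': 512}
print(clamp_max_tokens({}))                        # -> {}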
@@ -40,7 +40,8 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
                 presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
                 frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
-                length_penalty=parameters.get('length_penalty', self._default_params['length_penalty'])
+                length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
+                early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
             )
         except ValueError as e:
             return None, str(e).strip('.')
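This hunk adds the comma that was missing after the length_penalty argument and forwards early_stopping to the sampler as well. A rough sketch of the resulting parameter assembly, with the backend-specific sampler class replaced by a plain dict and _default_params stubbed, since only this fragment of VLLMBackend appears in the diff:

_default_params = {  # stand-in defaults; the real values live on VLLMBackend
    'max_tokens': 512,
    'presence_penalty': 0.0,
    'frequency_penalty': 0.0,
    'length_penalty': 1.0,
    'early_stopping': False,
}

def build_sampling_kwargs(parameters: dict) -> dict:
    # max_new_tokens wins over max_tokens, which falls back to the default;
    # the remaining knobs fall back to their defaults directly.
    return {
        'max_tokens': parameters.get('max_new_tokens') or parameters.get('max_tokens', _default_params['max_tokens']),
        'presence_penalty': parameters.get('presence_penalty', _default_params['presence_penalty']),
        'frequency_penalty': parameters.get('frequency_penalty', _default_params['frequency_penalty']),
        'length_penalty': parameters.get('length_penalty', _default_params['length_penalty']),
        'early_stopping': parameters.get('early_stopping', _default_params['early_stopping']),
    }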
@@ -34,9 +34,6 @@ def openai_chat_completions(model_name=None):

     if not request_json_body.get('stream'):
         try:
-            invalid_oai_err_msg = validate_oai(request_json_body)
-            if invalid_oai_err_msg:
-                return invalid_oai_err_msg
             return handler.handle_request()
         except Exception:
             traceback.print_exc()
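The non-streaming chat-completions route no longer calls validate_oai itself; that check moves into the request handler (see the following hunks). A sketch of the slimmed-down route body under assumed names (handler is constructed elsewhere, and the 500 fallback is hypothetical since the excerpt cuts off after the traceback):

import traceback

def handle_non_streaming(handler, request_json_body: dict):
    # Validation now happens inside handler.handle_request(), so the route
    # only delegates and guards against unexpected errors.
    if not request_json_body.get('stream'):
        try:
            return handler.handle_request()
        except Exception:
            traceback.print_exc()
            return 'Internal server error', 500  # hypothetical fallback, not shown in the diff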
@@ -38,10 +38,23 @@ class OpenAIRequestHandler(RequestHandler):
         oai_messages = self.request.json['messages']

         self.prompt = transform_messages_to_prompt(oai_messages)
+        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
             return invalid_response

+        if not self.prompt:
+            # TODO: format this as an openai error message
+            return Response('Invalid prompt'), 400
+
+        # TODO: support Ooba backend
+        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
+        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        if invalid_oai_err_msg:
+            return invalid_oai_err_msg
+
         if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
             try:
                 # Gather the last message from the user and all preceding system messages
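Taken together, this hunk and the next convert the request body with oai_to_vllm before validate_request and move the empty-prompt check, the parameter conversion, and the validate_oai call from after the moderation block to before it. A condensed outline of the new ordering, with the handler's own methods injected as plain callables since the full class is not shown here:

def handle_request_outline(body, parameters, prompt,
                           oai_to_vllm, validate_request, validate_oai, moderate):
    # 1. Convert the OpenAI-style body to vLLM conventions up front.
    body = oai_to_vllm(body)

    # 2. Validate the converted request.
    request_valid, invalid_response = validate_request()
    if not request_valid:
        return invalid_response

    # 3. Reject empty prompts early (TODO in the diff: proper OpenAI error format).
    if not prompt:
        return 'Invalid prompt', 400

    # 4. Convert the sampling parameters the same way, then re-check the body.
    parameters = oai_to_vllm(parameters)
    invalid_oai_err_msg = validate_oai(body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg

    # 5. Only then run the optional OpenAI moderation pass and generate.
    moderate(body)
    return parameters, prompt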
@@ -62,13 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
                 print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                 traceback.print_exc()

-        # TODO: support Ooba
-        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
-
-        if not self.prompt:
-            # TODO: format this as an openai error message
-            return Response('Invalid prompt'), 400
-
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
         model = self.request_json_body.get('model')
@@ -152,9 +158,13 @@ class OpenAIRequestHandler(RequestHandler):
         return response

     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
-        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        self.parameters, parameters_invalid_msg = self.get_parameters()
+        if not self.parameters:
+            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+            return False, (Response('Invalid request, check your parameters and try again.'), 400)
+        invalid_oai_err_msg = validate_oai(self.parameters)
         if invalid_oai_err_msg:
             return False, invalid_oai_err_msg
-        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         # If the parameters were invalid, let the superclass deal with it.
         return super().validate_request(prompt, do_log)
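In the last hunk, validate_request now parses the backend parameters first via get_parameters, fails fast with a generic 400 when they cannot be parsed, and runs validate_oai against the parsed parameters rather than the raw body (the body-level oai_to_vllm call is commented out because handle_request already performs it). A small sketch of that flow with stand-in helpers, since Response, get_parameters and validate_oai come from the surrounding project:

from typing import Optional, Tuple

def validate_request_sketch(get_parameters, validate_oai) -> Tuple[bool, Tuple[Optional[str], int]]:
    # Parse the backend parameters first; bail out with a generic 400 on failure.
    parameters, parameters_invalid_msg = get_parameters()
    if not parameters:
        print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
        return False, ('Invalid request, check your parameters and try again.', 400)

    # Validate the parsed parameters (not the raw request body) against the
    # OpenAI-compatible constraints.
    invalid_oai_err_msg = validate_oai(parameters)
    if invalid_oai_err_msg:
        return False, invalid_oai_err_msg

    # Everything else is deferred to the superclass in the real handler.
    return True, (None, 200)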