fix openai confusion

This commit is contained in:
Cyberes 2023-10-11 12:50:20 -06:00
parent 1d1c45dc1a
commit 69b8c1e35c
4 changed files with 24 additions and 13 deletions

View File

@ -28,6 +28,9 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
request_json_body['top_p'] = 0.01
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
if request_json_body['max_tokens'] == 0:
# We don't want to set any defaults here.
del request_json_body['max_tokens']
return request_json_body

View File

@ -40,7 +40,8 @@ class VLLMBackend(LLMBackend):
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty'])
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
)
except ValueError as e:
return None, str(e).strip('.')

View File

@ -34,9 +34,6 @@ def openai_chat_completions(model_name=None):
if not request_json_body.get('stream'):
try:
invalid_oai_err_msg = validate_oai(request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
return handler.handle_request()
except Exception:
traceback.print_exc()

View File

@ -38,10 +38,23 @@ class OpenAIRequestHandler(RequestHandler):
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
# TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
try:
# Gather the last message from the user and all preceding system messages
@ -62,13 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
traceback.print_exc()
# TODO: support Ooba
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model')
@ -152,9 +158,13 @@ class OpenAIRequestHandler(RequestHandler):
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
invalid_oai_err_msg = validate_oai(self.request_json_body)
self.parameters, parameters_invalid_msg = self.get_parameters()
if not self.parameters:
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
return False, (Response('Invalid request, check your parameters and try again.'), 400)
invalid_oai_err_msg = validate_oai(self.parameters)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)