"""Helpers for converting OpenAI-compatible requests and errors for the vLLM backend."""
from flask import jsonify

from llm_server.config.global_config import GlobalConfig
from llm_server.logging import create_logger

_logger = create_logger('oai_to_vllm')


def oai_to_vllm(request_json_body: dict, stop_hashes: bool, mode: str) -> dict:
    """Rewrite an OpenAI-style request body in place so the vLLM backend accepts it."""
    if not request_json_body.get('stop'):
        request_json_body['stop'] = []
    if not isinstance(request_json_body['stop'], list):
        # It is a string, so create a list containing the existing element.
        request_json_body['stop'] = [request_json_body['stop']]

    if stop_hashes:
        if GlobalConfig.get().openai_force_no_hashes:
            # Stop generation at '###' so hash-delimited headers never appear in the output.
            request_json_body['stop'].append('###')
        else:
            # TODO: make the stopping strings configurable
            request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
    else:
        request_json_body['stop'].extend(['user:', 'assistant:'])

    # Clamp frequency_penalty to the OpenAI-compatible range [-2, 2].
    if request_json_body.get('frequency_penalty', 0) < -2:
        request_json_body['frequency_penalty'] = -2
    elif request_json_body.get('frequency_penalty', 0) > 2:
        request_json_body['frequency_penalty'] = 2

    # vLLM requires top_p > 0, so nudge an exact zero up to the smallest accepted value.
    if mode == 'vllm' and request_json_body.get('top_p') == 0:
        request_json_body['top_p'] = 0.01

    # Accept either max_new_tokens or max_tokens, capped at the server-wide limit.
    request_json_body['max_tokens'] = min(
        max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)),
        GlobalConfig.get().max_new_tokens
    )
    if request_json_body['max_tokens'] == 0:
        # The client sent neither field. We don't want to set any defaults here.
        del request_json_body['max_tokens']

    return request_json_body
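# A worked example for oai_to_vllm() (hypothetical request; assumes the configured
# max_new_tokens is 512 and openai_force_no_hashes is disabled):
#
#   oai_to_vllm({'stop': 'foo', 'max_tokens': 1024, 'top_p': 0},
#               stop_hashes=True, mode='vllm')
#   -> {'stop': ['foo', '### INSTRUCTION', '### USER', '### ASSISTANT'],
#       'max_tokens': 512, 'top_p': 0.01}

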
def format_oai_err(err_msg):
    """Wrap an error message in an OpenAI-style invalid_request_error with HTTP 400."""
    _logger.error(f'Got an OAI error message: {err_msg}')
    return jsonify({
        "error": {
            "message": err_msg,
            "type": "invalid_request_error",
            "param": None,
            "code": None
        }
    }), 400
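# For example, format_oai_err('bad temperature') produces a (response, 400) tuple
# whose JSON body is:
#   {"error": {"message": "bad temperature", "type": "invalid_request_error",
#              "param": null, "code": null}}

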
def validate_oai(parameters):
    """Check OpenAI-style parameters, returning None when valid or the
    (response, 400) tuple from format_oai_err() on the first violation."""
    if parameters.get('messages'):
        for m in parameters['messages']:
            if m['role'].lower() not in ['assistant', 'user', 'system']:
                return format_oai_err('messages role must be assistant, user, or system')

    if parameters.get('temperature', 0) > 2:
        return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
    if parameters.get('temperature', 0) < 0:
        return format_oai_err(f"{parameters['temperature']} is less than the minimum of 0 - 'temperature'")

    if parameters.get('top_p', 1) > 1:
        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
    if parameters.get('top_p', 1) < 0:
        return format_oai_err(f"{parameters['top_p']} is less than the minimum of 0 - 'top_p'")

    if parameters.get('presence_penalty', 1) > 2:
        return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
    if parameters.get('presence_penalty', 1) < -2:
        return format_oai_err(f"{parameters['presence_penalty']} is less than the minimum of -2 - 'presence_penalty'")
    if parameters.get('max_tokens', 2) < 1:
        return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")


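# Typical caller pattern in a request handler (a sketch; names are illustrative):
#
#   err = validate_oai(request_json_body)
#   if err:
#       return err  # the (response, 400) tuple from format_oai_err()
#   request_json_body = oai_to_vllm(request_json_body, stop_hashes=True, mode='vllm')

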
def return_invalid_model_err(requested_model: str):
    if requested_model:
        msg = f"The model `{requested_model}` does not exist"
    else:
        msg = "The requested model does not exist"
    # The error response must be returned to the caller, not discarded.
    return return_oai_invalid_request_error(msg)


def return_oai_internal_server_error():
    return jsonify({
        "error": {
            "message": "Internal server error",
            "type": "auth_subrequest_error",
            "param": None,
            "code": "internal_error"
        }
    }), 500


def return_oai_invalid_request_error(msg: str = None):
    return jsonify({
        "error": {
            "message": msg,
            "type": "invalid_request_error",
            "param": None,
            "code": "model_not_found"
        }
    }), 404
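

if __name__ == '__main__':
    # Minimal manual check (a sketch, not one of the server's entry points).
    # jsonify() needs an application context, so spin up a throwaway Flask app.
    from flask import Flask

    _app = Flask(__name__)
    with _app.app_context():
        resp, status = return_oai_invalid_request_error('The model `example-model` does not exist')
        print(status, resp.get_json())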