98 lines
3.8 KiB
Python
98 lines
3.8 KiB
Python
|
from flask import jsonify
|
||
|
|
||
|
from llm_server import opts
|
||
|
|
||
|
|
||
|
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
    """Normalize an OpenAI-style request body, in place, for the inference backend.

    The dict passed in is mutated and the same object is returned.

    :param request_json_body: Parsed JSON body of an OpenAI-compatible request.
    :param stop_hashes: When true, add ``###``-style stop strings (a bare ``'###'``
        if ``opts.openai_force_no_hashes`` is set, otherwise the
        ``'### INSTRUCTION'`` / ``'### USER'`` / ``'### ASSISTANT'`` headers);
        when false, add ``'user:'`` and ``'assistant:'``.
    :param mode: Backend mode name; only ``'vllm'`` triggers the ``top_p == 0``
        adjustment below.
    :return: ``request_json_body`` itself, after normalization.
    """
    # Normalize "stop" into a list we can extend: missing/empty/null -> [],
    # a single string -> [string].
    stop = request_json_body.get('stop')
    if not stop:
        stop = []
    elif not isinstance(stop, list):
        # It is a string, so create a list with the existing element.
        stop = [stop]
    request_json_body['stop'] = stop

    if stop_hashes:
        if opts.openai_force_no_hashes:
            stop.append('###')
        else:
            # TODO: make stopping strings a configurable
            stop.extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
    else:
        stop.extend(['user:', 'assistant:'])

    # Clamp frequency_penalty into [-2, 2] rather than rejecting the request.
    # An explicit null means "use the default", so drop the key instead of
    # comparing None against a number (which raises TypeError).
    if 'frequency_penalty' in request_json_body:
        freq_penalty = request_json_body['frequency_penalty']
        if freq_penalty is None:
            del request_json_body['frequency_penalty']
        elif freq_penalty < -2:
            request_json_body['frequency_penalty'] = -2
        elif freq_penalty > 2:
            request_json_body['frequency_penalty'] = 2

    if mode == 'vllm' and request_json_body.get('top_p') == 0:
        # NOTE(review): presumably the vLLM backend does not accept top_p == 0,
        # so substitute a small positive value — confirm against vLLM version.
        request_json_body['top_p'] = 0.01

    # Accept either max_new_tokens or max_tokens (the larger wins), capped at the
    # server-wide opts.max_new_tokens. `or 0` treats an explicit null like "absent".
    requested_tokens = max(request_json_body.get('max_new_tokens') or 0,
                           request_json_body.get('max_tokens') or 0)
    request_json_body['max_tokens'] = min(requested_tokens, opts.max_new_tokens)
    if request_json_body['max_tokens'] == 0:
        # We don't want to set any defaults here.
        del request_json_body['max_tokens']

    return request_json_body
|
||
|
|
||
|
|
||
|
def format_oai_err(err_msg):
    """Print an OpenAI-compatible request error and build its HTTP 400 response.

    :param err_msg: Human-readable description of what was wrong with the request.
    :return: A ``(jsonify(...) response, 400)`` tuple whose body is
        ``{"error": {"message", "type": "invalid_request_error", "param": None, "code": None}}``.
    """
    print('OAI ERROR MESSAGE:', err_msg)
    error_body = dict(
        message=err_msg,
        type="invalid_request_error",
        param=None,
        code=None,
    )
    status_code = 400
    return jsonify({"error": error_body}), status_code
|
||
|
|
||
|
|
||
|
# Roles accepted in chat "messages" (compared case-insensitively).
_OAI_VALID_ROLES = ('assistant', 'user', 'system')


def _oai_range_error(parameters, name, minimum=None, maximum=None):
    """Return a 400 error response if ``parameters[name]`` is invalid, else None.

    A missing key or an explicit null is treated as "use the default" and passes.
    A non-numeric value is rejected. A value above ``maximum`` or below ``minimum``
    (each bound optional) is rejected.
    """
    value = parameters.get(name)
    if value is None:
        return None
    if not isinstance(value, (int, float)):
        # Comparing e.g. a string to a number would raise TypeError (a 500);
        # report it as a client error instead.
        return format_oai_err(f"{value!r} is not of type 'number' - '{name}'")
    if maximum is not None and value > maximum:
        return format_oai_err(f"{value} is greater than the maximum of {maximum} - '{name}'")
    if minimum is not None and value < minimum:
        return format_oai_err(f"{value} is less than the minimum of {minimum} - '{name}'")
    return None


def validate_oai(parameters):
    """Validate an OpenAI-style request body.

    :param parameters: Parsed JSON body of the request.
    :return: ``None`` if the parameters are acceptable, otherwise the
        ``(response, 400)`` tuple produced by ``format_oai_err`` for the first
        problem found.
    """
    messages = parameters.get('messages')
    if messages:
        for m in messages:
            role = m.get('role') if isinstance(m, dict) else None
            # A missing or non-string role used to raise KeyError/AttributeError.
            if not isinstance(role, str) or role.lower() not in _OAI_VALID_ROLES:
                return format_oai_err('messages role must be assistant, user, or system')

    # (parameter, minimum, maximum) — checked in this order; the first failure wins.
    # frequency_penalty is deliberately absent: oai_to_vllm clamps it instead.
    range_checks = (
        ('temperature', 0, 2),
        ('top_p', 0, 1),  # was checked against 2 although the limit (and message) is 1
        ('presence_penalty', -2, 2),
        ('max_tokens', 1, None),
    )
    for name, minimum, maximum in range_checks:
        err = _oai_range_error(parameters, name, minimum, maximum)
        if err is not None:
            return err
    return None
|
||
|
|
||
|
|
||
|
def return_invalid_model_err(requested_model: str):
    """Build the OpenAI-style HTTP 404 response for an unknown model.

    :param requested_model: Model name the client asked for. When it is falsy
        (e.g. empty), a generic message without the name is used.
    :return: A ``(jsonify(...) response, 404)`` tuple with error ``code``
        ``"model_not_found"``.
    """
    msg = (
        f"The model `{requested_model}` does not exist"
        if requested_model
        else "The requested model does not exist"
    )
    payload = {
        "error": dict(
            message=msg,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
    }
    return jsonify(payload), 404
|