add length penalty param to vllm

This commit is contained in:
Cyberes 2023-10-11 12:22:50 -06:00
parent 78114771b0
commit 1d1c45dc1a
2 changed files with 4 additions and 2 deletions

View File

@ -54,7 +54,8 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
def round_up_base(n, base): def round_up_base(n, base):
if base == 0: if base == 0:
print('round_up_base DIVIDE BY ZERO ERROR????', n, base) # TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues.
# print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
return 0 return 0
return math.ceil(n / base) * base return math.ceil(n / base) * base

View File

@ -39,7 +39,8 @@ class VLLMBackend(LLMBackend):
ignore_eos=parameters.get('ban_eos_token', False), ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']), max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']), presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']) frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty'])
) )
except ValueError as e: except ValueError as e:
return None, str(e).strip('.') return None, str(e).strip('.')