Adding error message when assert is violated.

This commit is contained in:
Nicolas Patry 2024-08-28 21:22:36 +02:00
parent e7e036389e
commit 9c839ca5df
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
1 changed files with 3 additions and 8 deletions

View File

@ -262,17 +262,12 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
# tokenized_input = tokenized_input[-r.truncate :]
# if (
# tokenized_input[0] == tokenizer.bos_token_id
# and tokenized_input[1] == tokenizer.bos_token_id
# ):
# tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input)
prefix_len = r.prefix_len
assert prefix_len <= orig_input_length
assert (
prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}"
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1