Adding error message when assert is violated.
This commit is contained in:
parent
e7e036389e
commit
9c839ca5df
|
@ -262,17 +262,12 @@ class FlashCausalLMBatch(Batch):
|
|||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
# tokenized_input = tokenized_input[-r.truncate :]
|
||||
# if (
|
||||
# tokenized_input[0] == tokenizer.bos_token_id
|
||||
# and tokenized_input[1] == tokenizer.bos_token_id
|
||||
# ):
|
||||
# tokenized_input = tokenized_input[1:]
|
||||
|
||||
orig_input_length = len(tokenized_input)
|
||||
|
||||
prefix_len = r.prefix_len
|
||||
assert prefix_len <= orig_input_length
|
||||
assert (
|
||||
prefix_len <= orig_input_length
|
||||
), f"Prefix {prefix_len} vs input {orig_input_length}"
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
|
|
Loading…
Reference in New Issue