From 9c839ca5df1cabb2a441d7ecdabb8ce900b567c3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 28 Aug 2024 21:22:36 +0200 Subject: [PATCH] Adding error message when assert is violated. --- .../text_generation_server/models/flash_causal_lm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a84ba765..9a60d06c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -262,17 +262,12 @@ class FlashCausalLMBatch(Batch): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - # tokenized_input = tokenized_input[-r.truncate :] - # if ( - # tokenized_input[0] == tokenizer.bos_token_id - # and tokenized_input[1] == tokenizer.bos_token_id - # ): - # tokenized_input = tokenized_input[1:] - orig_input_length = len(tokenized_input) prefix_len = r.prefix_len - assert prefix_len <= orig_input_length + assert ( + prefix_len <= orig_input_length + ), f"Prefix {prefix_len} vs input {orig_input_length}" if prefix_len == orig_input_length: assert prefix_len > 0 prefix_len -= 1