diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 25c49183..706f3fdb 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1723,9 +1723,11 @@ class FlashCausalLM(Model): # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, # then update the slots with the additional indices to ensure we're grabbing the ones that have been # allocated - slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + slot_indices = ( + batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) slots = batch.slots[slot_indices] - + input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1)