diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8ab1a811..25c49183 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )
 
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None
 
         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
@@ -1718,7 +1719,13 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
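
For reference, here is a minimal standalone sketch (not part of the patch) of the gather pattern introduced in the second hunk: when prefix caching leaves the allocated slots discontiguous, the per-request slot indices are expanded by the speculation offsets and then used to index into the flat slots tensor. The names B, new_length, slots, and slot_indices mirror the diff; all concrete values are illustrative assumptions, not taken from the repository.

import torch

# Illustrative values only: 2 requests (B=2), each producing 3 tokens this step
# (1 decoded + 2 speculated), i.e. new_length = 3.
B, new_length = 2, 3

# Flat tensor of allocated KV-cache slots; discontiguous because prefix caching
# may reserve non-adjacent blocks for each request.
slots = torch.tensor([10, 11, 12, 40, 41, 42, 43])

# Starting index into `slots` for each request.
slot_indices = torch.tensor([0, 3])

# Offsets 0..new_length-1 added to every request's starting index.
arange_int = torch.arange(new_length, dtype=torch.int64)

# Same expansion as in the patch: one index per (request, token) pair.
expanded_indices = (
    slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)

# Gather the actual slot numbers instead of assuming they increase contiguously.
gathered_slots = slots[expanded_indices]
print(gathered_slots)  # tensor([10, 11, 12, 40, 41, 42])

This mirrors the design choice in the patch: rather than adding the offsets to the slot numbers themselves (the removed line), the offsets advance the indices into batch.slots, so whichever slots were actually allocated are picked up.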