Fix prefix caching + speculative decoding

Travis Addair 2024-10-31 13:43:49 -07:00 committed by GitHub
parent befd9f6735
commit 3bafa0eb7b
1 changed file with 13 additions and 6 deletions


@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None

         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
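
A small self-contained sketch (not part of the commit) of why the new guard checks every batch rather than only the first: if any batch skipped speculative decoding, its speculative_ids is None and torch.cat would fail, so the merged batch must drop speculative ids entirely. The tensors and the speculate value below are made up, standing in for FlashCausalLMBatch.speculative_ids and get_speculate().

    import torch

    # Hypothetical stand-ins for the batches being concatenated; in the real code
    # these are FlashCausalLMBatch objects whose speculative_ids may be None when
    # speculative decoding was skipped for that batch.
    batch_a_ids = torch.tensor([[3, 7], [1, 4]])   # speculative ids present
    batch_b_ids = None                             # speculative decoding was skipped here

    speculate = 2  # stand-in for get_speculate()
    candidates = [batch_a_ids, batch_b_ids]

    if speculate > 0 and all(ids is not None for ids in candidates):
        # Safe to concatenate: every batch carries speculative ids.
        speculative_ids = torch.cat(candidates, dim=0)
    else:
        # At least one batch has no speculative ids, so discard them for the
        # merged batch instead of calling torch.cat with a None entry.
        speculative_ids = None

    print(speculative_ids)  # None in this example

Checking only batches[0].speculative_ids, as the removed code did, misses the case where a later batch is the one missing them.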
@@ -1718,7 +1719,13 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (
+                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            slots = batch.slots[slot_indices]
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
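
A standalone sketch (not from the commit) of what the comment above describes: with prefix caching a request's allocated slots need not be consecutive, so offsetting the slot value directly (the removed line) can point at slots that were never allocated, while offsetting slot_indices and gathering from batch.slots stays inside the allocation. All tensor values are invented for illustration.

    import torch

    # Made-up allocation: with prefix caching the slots owned by a request can be
    # discontiguous (e.g. part of the KV cache is reused from a cached prefix).
    slots = torch.tensor([10, 11, 40, 41, 42])   # allocated slot numbers
    slot_indices = torch.tensor([1])             # current position within `slots`
    B, new_length = 1, 3                         # one request, 3 speculative positions
    arange_int = torch.arange(new_length)

    # Old behaviour: offset the slot *value* directly -> 11, 12, 13,
    # but slots 12 and 13 were never allocated to this request.
    old_slots = (slots[slot_indices].unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)

    # Fixed behaviour: offset the *index* into the allocation, then gather,
    # which yields 11, 40, 41 -- all genuinely allocated slots.
    new_indices = (slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
    new_slots = slots[new_indices]

    print(old_slots.tolist())  # [11, 12, 13]
    print(new_slots.tolist())  # [11, 40, 41]

When the slots happen to be contiguous the two computations agree, which is why the bug only shows up once prefix caching and speculative decoding are enabled together.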