Fix prefix caching + speculative decoding (#2711)
parent a5593ba83e
commit aadc9cb485
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )
 
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None
 
         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
@@ -1724,7 +1725,13 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
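What the new indexing buys: with prefix caching, the slots allocated to a request need not be contiguous, so adding the per-token offset to the first slot *value* (the old line) can address slots that were never allocated. Offsetting the slot *indices* and gathering from the slots table honors the real allocation. A self-contained toy illustration of the difference; the tensor values here are invented for the example, not taken from the PR:

```python
import torch

# Toy sizes: B requests, each producing new_length speculative tokens.
B, new_length = 2, 3
arange_int = torch.arange(new_length, dtype=torch.int64)

# Flat table of allocated slots. The second request's slots (40, 50, 60)
# are deliberately discontiguous, as prefix caching can make them.
slots = torch.tensor([10, 11, 12, 40, 50, 60])
slot_indices = torch.tensor([0, 3])  # each request's first index into `slots`

# Old (broken with prefix caching): add the offset to the slot values.
first_slots = slots[slot_indices]
broken = (first_slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)

# Fixed: offset the indices, then gather, so discontiguous slots are honored.
expanded_indices = (slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
fixed = slots[expanded_indices]

print(broken.tolist())  # [10, 11, 12, 40, 41, 42] -> 41, 42 were never allocated
print(fixed.tolist())   # [10, 11, 12, 40, 50, 60] -> the actual allocated slots
```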