Fix prefix caching + speculative decoding
This commit is contained in:
parent
befd9f6735
commit
3bafa0eb7b
|
@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
|
||||||
fsm_grammar_states=fsm_grammar_states,
|
fsm_grammar_states=fsm_grammar_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculative_ids = (
|
# We skip computing the speculative_ids when the batch size is too large, so
|
||||||
torch.cat([b.speculative_ids for b in batches], dim=0)
|
# we must check that all batches have them, otherwise they must be discarded
|
||||||
if batches[0].speculative_ids is not None
|
if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
|
||||||
else None
|
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
|
||||||
)
|
else:
|
||||||
|
speculative_ids = None
|
||||||
|
|
||||||
if adapter_segment_builder is not None:
|
if adapter_segment_builder is not None:
|
||||||
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
||||||
|
@ -1718,7 +1719,13 @@ class FlashCausalLM(Model):
|
||||||
new_position_ids = (
|
new_position_ids = (
|
||||||
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||||
).view(-1)
|
).view(-1)
|
||||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
|
||||||
|
# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
|
||||||
|
# then update the slots with the additional indices to ensure we're grabbing the ones that have been
|
||||||
|
# allocated
|
||||||
|
slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||||
|
slots = batch.slots[slot_indices]
|
||||||
|
|
||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
|
|
Loading…
Reference in New Issue