Fixing linting on main. (#2719)

This commit is contained in:
Nicolas Patry 2024-11-04 22:21:41 +08:00 committed by GitHub
parent aadc9cb485
commit 9fde566602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -1729,9 +1729,11 @@ class FlashCausalLM(Model):
# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
# then update the slots with the additional indices to ensure we're grabbing the ones that have been
# allocated
slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
slot_indices = (
batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
slots = batch.slots[slot_indices]
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)