Fixing auto bloom test. (#2699)

Nicolas Patry 2024-10-28 06:14:11 +01:00 committed by GitHub
parent 78ce618c70
commit 3a9cdc3241
1 changed file with 2 additions and 2 deletions


@@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
                 past_value.view(-1, *past_value.shape[-2:]),
             )
-        if CUSTOM_KERNELS_ENABLED:
+        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
             assert self.training is False, "Only foward pass was implemented"
             assert (
                 attention_mask.shape[-1] < 4096
@@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     @staticmethod
     def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
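For context, a minimal runnable sketch of the dispatch logic the first hunk introduces. Only CUSTOM_KERNELS_ENABLED, attention_mask.shape[-1], and the 4096 limit come from the diff; dispatch_attention, MAX_KERNEL_SEQ_LEN, and the string return values are hypothetical illustration names. The idea: the custom-kernel branch already asserts attention_mask.shape[-1] < 4096, so folding the same condition into the branch test lets longer sequences fall through to the regular path instead of failing the assert.

import torch

# Hypothetical stand-in for the guarded dispatch added by this commit.
MAX_KERNEL_SEQ_LEN = 4096  # mirrors the kernel path's own assert on attention_mask.shape[-1]

def dispatch_attention(attention_mask: torch.Tensor, custom_kernels_enabled: bool) -> str:
    # Before this commit the branch was taken whenever custom kernels were
    # enabled, so a mask of 4096+ positions tripped the assert inside the
    # kernel path; the added length check routes those cases to the fallback.
    if custom_kernels_enabled and attention_mask.shape[-1] < MAX_KERNEL_SEQ_LEN:
        return "custom_kernel"
    return "default_path"

# A 4096-wide mask now falls back cleanly instead of asserting.
mask = torch.ones(1, 4096)
assert dispatch_attention(mask, custom_kernels_enabled=True) == "default_path"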