Fixing auto bloom test. (#2699)
commit 3a9cdc3241
parent 78ce618c70
@@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
             past_value.view(-1, *past_value.shape[-2:]),
         )
 
-        if CUSTOM_KERNELS_ENABLED:
+        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
             assert self.training is False, "Only foward pass was implemented"
             assert (
                 attention_mask.shape[-1] < 4096
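
The first hunk changes how the custom-kernel path is selected. Previously the branch was entered whenever CUSTOM_KERNELS_ENABLED was set, and an assert inside the branch then required attention_mask.shape[-1] < 4096, so any longer sequence crashed. After the fix, the length check is part of the branch condition itself, and longer sequences fall through to the standard PyTorch attention path. A minimal sketch of the resulting dispatch; everything except the 4096 limit and the custom-kernels flag is a hypothetical stand-in for the surrounding BloomAttention code:

import torch

MAX_CUSTOM_KERNEL_SEQ_LEN = 4096  # limit taken from the diff

def attention_dispatch(attention_mask: torch.Tensor,
                       custom_kernels_enabled: bool,
                       custom_path, fallback_path):
    # Long masks no longer trip an assert inside the custom branch;
    # they simply take the standard PyTorch path instead.
    if custom_kernels_enabled and attention_mask.shape[-1] < MAX_CUSTOM_KERNEL_SEQ_LEN:
        return custom_path()
    return fallback_path()
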
@@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
 
     @staticmethod
     def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))