Fixing auto bloom test. (#2699)
This commit is contained in:
parent
78ce618c70
commit
3a9cdc3241
|
@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
|
||||||
past_value.view(-1, *past_value.shape[-2:]),
|
past_value.view(-1, *past_value.shape[-2:]),
|
||||||
)
|
)
|
||||||
|
|
||||||
if CUSTOM_KERNELS_ENABLED:
|
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
|
||||||
assert self.training is False, "Only foward pass was implemented"
|
assert self.training is False, "Only foward pass was implemented"
|
||||||
assert (
|
assert (
|
||||||
attention_mask.shape[-1] < 4096
|
attention_mask.shape[-1] < 4096
|
||||||
|
@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_to_bloom_cache(
|
def _convert_to_bloom_cache(
|
||||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||||
|
|
Loading…
Reference in New Issue