From 3a9cdc324100d567cb28f15823e3be010fe284be Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 28 Oct 2024 06:14:11 +0100 Subject: [PATCH] Fixing auto bloom test. (#2699) --- .../models/custom_modeling/bloom_modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e2719fad..84835ab8 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -377,7 +377,7 @@ class BloomAttention(nn.Module): past_value.view(-1, *past_value.shape[-2:]), ) - if CUSTOM_KERNELS_ENABLED: + if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096: assert self.training is False, "Only foward pass was implemented" assert ( attention_mask.shape[-1] < 4096 @@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel): @staticmethod def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: """ Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))