From b4cf832c40f1fc17401968198b739bdf8351d9f1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 27 Apr 2023 00:51:27 -0700 Subject: [PATCH] fix(server): fix reshaping of bloom past_key_values in concatenate() (#252) Introduced in #214 Fixes #249 --- server/text_generation_server/models/causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 336c982..ca8fccf 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -335,7 +335,7 @@ class CausalLMBatch(Batch): [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values ] - elif batch.past_key_values[0][0].shape == 3: + elif len(batch.past_key_values[0][0].shape) == 3: for layer in batch.past_key_values: for k, t in enumerate(layer): layer[k] = t.view(len(batch), -1, *t.shape[-2:])