diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 336c982..ca8fccf 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -335,7 +335,7 @@ class CausalLMBatch(Batch): [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values ] - elif batch.past_key_values[0][0].shape == 3: + elif len(batch.past_key_values[0][0].shape) == 3: for layer in batch.past_key_values: for k, t in enumerate(layer): layer[k] = t.view(len(batch), -1, *t.shape[-2:])