fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)

Introduced in #214 

Fixes #249
This commit is contained in:
Nick Hill 2023-04-27 00:51:27 -07:00 committed by GitHub
parent db2b4e0754
commit b4cf832c40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -335,7 +335,7 @@ class CausalLMBatch(Batch):
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif batch.past_key_values[0][0].shape == 3:
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])