Fix incorrect cache allocation with multi-query (#2203)

Previously, no KV-cache memory was allocated when using multi-query attention
(a single KV head). This fixes Starcoder and similar models.
This commit is contained in:
Daniël de Kok 2024-07-08 11:19:48 +02:00 committed by GitHub
parent cce475a949
commit 153fcf7739
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 6 additions and 1 deletion

View File

@ -912,7 +912,12 @@ class FlashCausalLM(Model):
break
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = num_kv_heads // self.process_group.size()
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {}