Fix incorrect cache allocation with multi-query (#2203)
Previously, no KV-cache memory was allocated in the multi-query attention case (1 KV head), because integer division by the process-group size rounded the head count down to zero. Fixes Starcoder et al.
This commit is contained in:
parent
cce475a949
commit
153fcf7739
|
@ -912,7 +912,12 @@ class FlashCausalLM(Model):
|
||||||
break
|
break
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
raise ValueError("Cannot get the number of key/value heads")
|
raise ValueError("Cannot get the number of key/value heads")
|
||||||
self.num_kv_heads = num_kv_heads // self.process_group.size()
|
self.num_kv_heads = (
|
||||||
|
num_kv_heads // self.process_group.size()
|
||||||
|
if num_kv_heads > 1
|
||||||
|
else num_kv_heads
|
||||||
|
)
|
||||||
|
assert self.num_kv_heads > 0
|
||||||
self.head_size = config.hidden_size // config.num_attention_heads
|
self.head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
|
||||||
self.cuda_graphs = {}
|
self.cuda_graphs = {}
|
||||||
|
|
Loading…
Reference in New Issue