Fixing graph capture for flash decoding. (#2163)

Nicolas Patry 2024-07-02 11:43:07 +02:00 committed by GitHub
parent 4327210e6b
commit 022f6515a4
1 changed file with 3 additions and 2 deletions


@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
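
Why this matters: CUDA graph replay only re-executes the operations recorded inside the torch.cuda.graph capture region. Before this change, Seqlen(input_lengths=input_lengths) was constructed once, outside the capture, so any tensors it derives for flash decoding stayed frozen at their warmup values across replays. The fix keeps the raw tensor available under the name input_lengths_ for the eager warmup call, and rebuilds the Seqlen wrapper inside the capture so its derivation is recorded in the graph. A minimal sketch of this capture rule, assuming a cumsum-style derived tensor (illustrative names, not the actual Seqlen internals):

import torch

# Static input buffer that will be mutated in place between replays.
static_lengths = torch.ones(4, dtype=torch.int32, device="cuda")

# Warmup outside the graph so allocations and kernel launches settle.
torch.cumsum(static_lengths, dim=0)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Recorded in the graph: every replay recomputes this from the
    # current contents of static_lengths. If this line ran before the
    # capture instead, replays would keep reusing the warmup result.
    cu_seqlen = torch.cumsum(static_lengths, dim=0)

static_lengths.fill_(3)  # update the static input in place
graph.replay()
print(cu_seqlen)  # tensor([3, 6, 9, 12], ...) -- reflects the update

The same reasoning explains the rename in the first hunk: the warmup forward call still needs a Seqlen (hence input_lengths_), while the original input_lengths tensor must survive so it can be re-wrapped inside the capture.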