Fixing graph capture for flash decoding. (#2163)
commit 022f6515a4
parent 4327210e6b
@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
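The change moves the Seqlen construction inside the torch.cuda.graph capture context: any tensor work done when Seqlen is built (for flash decoding this is assumed to include something like a cumulative sum over input_lengths) is then recorded in the graph and re-runs on every replay, while the warmup forward before capture uses a throwaway input_lengths_ wrapper so the raw input_lengths tensor stays available for capture. A minimal sketch of the pattern, assuming a hypothetical make_descriptor helper in place of Seqlen and omitting TGI's shared MEM_POOL:

import torch

def make_descriptor(lengths: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for Seqlen: constructing it runs tensor
    # ops (here, a cumulative sum) that must be part of the graph.
    return torch.cumsum(lengths, dim=0)

input_lengths = torch.ones(4, dtype=torch.int32, device="cuda")

# Warmup outside the capture, using a throwaway wrapper so the raw
# input_lengths tensor is left untouched for the capture below.
descriptor_ = make_descriptor(input_lengths)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Built inside the capture: the cumsum is recorded and re-runs
    # on replay, reading whatever values input_lengths holds then.
    descriptor = make_descriptor(input_lengths)

# Replay with new lengths: update the static input tensor in place.
input_lengths.copy_(
    torch.tensor([2, 3, 4, 5], dtype=torch.int32, device="cuda")
)
graph.replay()

Had the descriptor been built before the with block, as the old code did, its derived values would be frozen at warmup time and replays would never see updated lengths, which appears to be the capture issue this commit addresses.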