Fixing graph capture for flash decoding. (#2163)
This commit is contained in:
parent
4327210e6b
commit
022f6515a4
|
@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
|
||||||
"slots": slots,
|
"slots": slots,
|
||||||
"input_lengths": input_lengths,
|
"input_lengths": input_lengths,
|
||||||
}
|
}
|
||||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
input_lengths_ = Seqlen(input_lengths=input_lengths)
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
self.cuda_graphs[bs]["graph"] = graph
|
self.cuda_graphs[bs]["graph"] = graph
|
||||||
|
|
||||||
|
@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
|
||||||
kv_cache=self.kv_cache,
|
kv_cache=self.kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths_,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
lm_head_indices=None,
|
lm_head_indices=None,
|
||||||
|
@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||||
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|
Loading…
Reference in New Issue