From 022f6515a48c26cd505b4bbfa8da3bd00e77f078 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 2 Jul 2024 11:43:07 +0200
Subject: [PATCH] Fixing graph capture for flash decoding. (#2163)

---
 server/text_generation_server/models/flash_causal_lm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 49a088a1..4f276ed4 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
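
Why this fixes graph capture: CUDA graph replay re-executes only the kernels recorded inside the `torch.cuda.graph` context, reading whatever data the captured tensors hold at replay time. The `Seqlen` wrapper presumably derives tensors from `input_lengths` (e.g. cumulative sequence lengths for flash decoding), so it has to be rebuilt inside the capture; the wrapper created before capture (renamed to `input_lengths_`) now serves only the throwaway warmup forward, while `self.cuda_graphs[bs]["input_lengths"]` keeps the raw tensor that replays copy new data into. Below is a minimal, TGI-independent sketch of the same pattern; the names `static_input`, `derived`, and `out` are hypothetical, not from the patch:

    import torch

    # Persistent "static" input: the captured kernels read this buffer,
    # so new data is fed by copying into it before each replay.
    static_input = torch.ones(4, device="cuda")

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_input.cumsum(0)  # throwaway warmup run
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Recorded during capture: replay recomputes `derived` from the
        # current contents of `static_input`. Had `derived` been built
        # before capture, its values would stay frozen at the warmup data,
        # which is the analogue of the `Seqlen` bug the patch fixes.
        derived = static_input.cumsum(0)
        out = derived * 2

    static_input.copy_(torch.arange(4.0, device="cuda"))
    graph.replay()
    print(out)  # tensor([ 0.,  2.,  6., 12.]), reflecting the new input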