diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36f70180..cf9ac405 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1389,29 +1389,48 @@ class FlashCausalLM(Model): ] def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None input_lengths = [max_s] * bs cache_lengths = [0] * bs - input_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - ) - cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).repeat(bs) - block_tables = block_tables.reshape((bs, max_bt)) + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=input_lengths, - cache_lengths=cache_lengths, - input_lengths_tensor=input_lengths_tensor, - cache_lengths_tensor=cache_lengths_tensor, - max_current_length=max_s, - ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, )