From b5f1c9de06ad00bbdeec0348c47f53bee271cedc Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 17 May 2024 18:21:51 +0200 Subject: [PATCH] Fix TunableOp bug (#1920) cc @Narsil --- docs/source/basic_tutorials/monitoring.md | 2 +- .../models/flash_mistral.py | 22 +++++++++++++++++ server/text_generation_server/models/mamba.py | 24 +++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/docs/source/basic_tutorials/monitoring.md b/docs/source/basic_tutorials/monitoring.md index a24cf902..d6e50cfd 100644 --- a/docs/source/basic_tutorials/monitoring.md +++ b/docs/source/basic_tutorials/monitoring.md @@ -72,4 +72,4 @@ Once Prometheus data source is configured, we can finally create our dashboard! Community contributed dashboard templates are also available, for example [here](https://grafana.com/grafana/dashboards/19831-text-generation-inference-dashboard/) or [here](https://grafana.com/grafana/dashboards/20246-text-generation-inference/). -Load your dashboard configuration, and your TGI dashboard should be ready to go! \ No newline at end of file +Load your dashboard configuration, and your TGI dashboard should be ready to go! diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index b83f49a4..30ae95c9 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -391,6 +391,28 @@ class BaseFlashMistral(FlashCausalLM): def batch_type(self) -> Type[FlashMistralBatch]: return FlashMistralBatch + def tunableop_warmup(self, seqlen: int): + input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) + kv_cache = get_cache_manager().kv_cache + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ), + kv_cache=get_cache_manager().kv_cache, + block_tables=None, + input_lengths=None, + slots=slots, + max_s=seqlen, + lm_head_indices=None, + prefill_cache_indices=None, + ) + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index b28b744f..d9f90590 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -522,6 +522,30 @@ class Mamba(Model): } self.cuda_graphs[batch_size] = graph_dict + def tunableop_warmup(self, seqlen: int): + input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) + n_blocks = len(self.model.blocks) + + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + # Inner takes the expand multiplication + d_inner = self.model.config.d_inner + + # Important seqlen_offset to go through the update mecanism with the state + seqlen_offset = 1 + inference_params = new_inference_params( + n_blocks=n_blocks, + batch_size=seqlen, + d_state=d_state, + d_conv=d_conv, + d_inner=d_inner, + seqlen_offset=seqlen_offset, + device=self.device, + dtype=self.dtype, + ) + + self.model.forward(input_ids=input_ids, inference_params=inference_params) + def forward( self, input_ids: torch.Tensor, inference_params: Any ) -> Tuple[torch.Tensor, torch.Tensor]: