diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index acd97f45..0d978920 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -798,11 +798,13 @@ class FlashCausalLM(Model):
                 self.device,
             )
 
+        self.compiled_model = torch.compile(self.model, mode="reduce-overhead")
+
         if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
+                for bs in [1]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
@@ -881,7 +883,19 @@ class FlashCausalLM(Model):
             or cuda_graph is None
             or batch.speculative_ids is not None
         ):
-            return self.model.forward(
+            if cu_seqlen_prefill is None:
+                return self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    lm_head_indices=lm_head_indices,
+                )
+            return self.model(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index a2ac759a..6be1adb4 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -68,6 +68,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
+        # model = torch.compile(model, mode="reduce-overhead")
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8149c1b0..423ca206 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -495,18 +495,33 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
+
+            if cu_seqlen_prefill is None:
+                logits, speculative_logits = self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+            else:
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d5adbd32..9eabfe41 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -149,7 +149,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
             concat_ns = None
 
-        generations, next_batch, timings = self.model.generate_token(batch)
+        torch.profiler._utils._init_for_cuda_graphs()
+        # prof = torch.profiler.profile()
+        # if self.model.rank != 0:
+        if True:
+            import contextlib
+
+            prof = contextlib.nullcontext()
+        else:
+            prof = torch.profiler.profile()
+        with prof:
+            generations, next_batch, timings = self.model.generate_token(batch)
+        # if self.model.rank == 0:
+        #     prof.export_chrome_trace(f"out_rank_0.json")
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 209f1c8a..7c9954a5 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -507,27 +507,27 @@ class TensorParallelHead(SuperLayer):
             return super().forward(input)
 
         world_size = self.process_group.size()
-        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
-            out_dim = self.linear.weight.shape[0]
+        # if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+        #     out_dim = self.linear.weight.shape[0]
 
-            if input.shape[0] == 1:
-                world_out = input.new_empty(1, out_dim * world_size)
-                local_out = input.new_empty(1, out_dim)
-                gather_input = local_out
-            else:
-                world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                gather_input = input.new_empty(out_dim, input.shape[0])
-                local_out = gather_input.T
+        #     if input.shape[0] == 1:
+        #         world_out = input.new_empty(1, out_dim * world_size)
+        #         local_out = input.new_empty(1, out_dim)
+        #         gather_input = local_out
+        #     else:
+        #         world_out = input.new_empty(out_dim * world_size, input.shape[0])
+        #         gather_input = input.new_empty(out_dim, input.shape[0])
+        #         local_out = gather_input.T
 
-            torch.mm(input, self.linear.weight.T, out=local_out)
+        #     torch.mm(input, self.linear.weight.T, out=local_out)
 
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+        #     torch.distributed.all_gather_into_tensor(
+        #         world_out, gather_input, group=self.process_group
+        #     )
 
-            if input.shape[0] == 1:
-                return world_out
-            return world_out.T
+        #     if input.shape[0] == 1:
+        #         return world_out
+        #     return world_out.T
 
         output = super().forward(input)
         world_output = [
@@ -786,6 +786,7 @@ try:
             self._sin_k_cached = None
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
+            self._update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)
 
         def forward(
             self,
@@ -929,8 +930,6 @@ try:
                 # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                 dtype = torch.float32
 
-            self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
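
Note (not part of the patch): a minimal standalone sketch of the routing pattern the diff applies in FlashCausalLM/BaseFlashMistral: compile the module once with torch.compile(..., mode="reduce-overhead") and send decode-only calls (the case where cu_seqlen_prefill is None) through the compiled callable, while prefill stays on the eager model. TinyLM, its shapes, and the is_prefill flag are placeholders for illustration, assuming PyTorch >= 2.0; this is not the TGI implementation.

import torch


class TinyLM(torch.nn.Module):
    """Stand-in for the real model; only here to make the sketch runnable."""

    def __init__(self, hidden: int = 64, vocab: int = 128):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(torch.relu(self.proj(hidden_states)))


model = TinyLM()
# Compile once at load time; "reduce-overhead" targets small, launch-bound
# steps (it can use CUDA graphs on GPU and still runs, without them, on CPU).
compiled_model = torch.compile(model, mode="reduce-overhead")


def forward(hidden_states: torch.Tensor, is_prefill: bool) -> torch.Tensor:
    # Same routing as the patch: compiled path only for decode steps,
    # eager path for prefill.
    if not is_prefill:
        return compiled_model(hidden_states)
    return model(hidden_states)


logits = forward(torch.randn(1, 64), is_prefill=False)  # decode -> compiled
logits = forward(torch.randn(4, 64), is_prefill=True)   # prefill -> eager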