From f6ad3b35854e81002fa0bdc135ca51f0e2f5531a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 15 Jul 2024 11:47:52 +0000 Subject: [PATCH] Some MoE exploration --- server/marlin/marlin_kernels/gptq_marlin.cu | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 52 +++++++++++++------ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu index 0beb9de1..932ff2b7 100644 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ b/server/marlin/marlin_kernels/gptq_marlin.cu @@ -1465,7 +1465,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, max_m_blocks--; if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + //TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); } } @@ -1583,7 +1583,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int thread_n, int sms, int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + TORCH_CHECK(prob_m >= 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); int tot_m = prob_m; diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 49c0e903..96c865a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -442,29 +442,51 @@ class DenseMoE(nn.Module): gate_logits = self.gate(x) # all_probs: (sequence_length, n_experts) and upcast for softmax all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) + # (sequence_length, n_experts_per_tok) + routing_weights, selected_experts = torch.topk(all_probs, self.top_k, dim=1) + routing_weights /= routing_weights.sum(dim=1, keepdim=True) + routing_weights.to(x.dtype) - if self.top_k < self.num_experts: - _, not_selected_experts = torch.topk( - all_probs, - self.num_experts - self.top_k, - largest=False, - sorted=False, - dim=1, - ) - # Mask not selected experts - all_probs.scatter_(1, not_selected_experts, 0) + # logger.info( + # f"routing_weights: {routing_weights.shape}, selected_experts: {selected_experts.shape}" + # ) - # Re-normalize - weights = all_probs / all_probs.sum(dim=1, keepdim=True) - weights = weights.to(x.dtype) + # logger.info( + # f"expert mask before permute: {torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape}", + # ) + # expert_mask: (n_experts, sequence_length, n_experts_per_tok) + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + # logger.info(f"expert mask shape: {expert_mask.shape}") + # logger.info(f"expert mask: {expert_mask}") + + indices = torch.empty( + (x.shape[0] * self.top_k, 2), dtype=torch.long, device=x.device + ) + # logger.info(f"indices shape: {indices.shape}") # Final output tensor out = x.new_zeros(x.shape[0], self.hidden_dim) for i in range(self.num_experts): - h = self.act(self.w1[i](x)) * self.w3[i](x) + # idx, top_x = torch.where(expert_mask[i]) + torch.nonzero(expert_mask[i], out=indices) + idx = indices.t()[0] + top_x = indices.t()[1] + h = x[None, top_x].reshape(-1, self.hidden_dim) + + # Sometimes an expert is not used for any tokens. However, some matmul + # kernels do not support empty batches (e.g. Marlin). + # if h.shape[0] == 0: + # continue + + h = self.act(self.w1[i](h)) * self.w3[i](h) h = self.w2[i](h, reduce=False) + # logger.info(f"top_x: {top_x}, idx: {idx}") + # logger.info(f"routing weights shape: {routing_weights.shape}") + h *= routing_weights[top_x, idx, None] # Add expert output to out with masking - out += h * weights[:, i].view(-1, 1) + out.index_add_(0, top_x, h) # Reduce sum if self.process_group.size() > 1: