Some MoE exploration

Daniël de Kok 2024-07-15 11:47:52 +00:00
parent dbb23fbfa8
commit f6ad3b3585
2 changed files with 39 additions and 17 deletions


@@ -1465,7 +1465,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
     max_m_blocks--;
     if (max_m_blocks == 0) {
-      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
+      //TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
     }
   }
@@ -1583,7 +1583,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
                      int thread_n, int sms, int max_par) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
-  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
+  TORCH_CHECK(prob_m >= 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
               ", ", prob_n, ", ", prob_k, "]");
   int tot_m = prob_m;
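Both kernel changes relax guards for the same reason: with top-k routing, an expert can receive zero tokens, so the Marlin entry point now has to tolerate prob_m == 0 rather than abort (see the MoE loop below). A minimal sketch, assuming nothing beyond plain PyTorch, of how an empty M dimension arises:

```python
import torch

# A batch where no token was routed to this expert.
x = torch.randn(5, 64)                      # 5 tokens, hidden size 64
top_x = torch.tensor([], dtype=torch.long)  # empty routing index
h = x[top_x]                                # shape (0, 64): prob_m == 0
w = torch.randn(64, 64)
print((h @ w).shape)                        # torch.Size([0, 64]); a dense matmul
                                            # accepts M == 0, Marlin previously did not
```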


@@ -442,29 +442,51 @@ class DenseMoE(nn.Module):
         gate_logits = self.gate(x)
         # all_probs: (sequence_length, n_experts) and upcast for softmax
         all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
+        # routing_weights, selected_experts: (sequence_length, n_experts_per_tok)
+        routing_weights, selected_experts = torch.topk(all_probs, self.top_k, dim=1)
+        routing_weights /= routing_weights.sum(dim=1, keepdim=True)
+        routing_weights = routing_weights.to(x.dtype)
+        # logger.info(
+        #     f"routing_weights: {routing_weights.shape}, selected_experts: {selected_experts.shape}"
+        # )
+        # logger.info(
+        #     f"expert mask before permute: {torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape}",
+        # )
+        # expert_mask: (n_experts, n_experts_per_tok, sequence_length)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+        # logger.info(f"expert mask shape: {expert_mask.shape}")
+        # logger.info(f"expert mask: {expert_mask}")
+        indices = torch.empty(
+            (x.shape[0] * self.top_k, 2), dtype=torch.long, device=x.device
+        )
+        # logger.info(f"indices shape: {indices.shape}")
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            # idx, top_x = torch.where(expert_mask[i])
+            torch.nonzero(expert_mask[i], out=indices)
+            idx = indices.t()[0]
+            top_x = indices.t()[1]
+            h = x[None, top_x].reshape(-1, self.hidden_dim)
+            # Sometimes an expert is not used for any tokens. However, some matmul
+            # kernels do not support empty batches (e.g. Marlin).
+            # if h.shape[0] == 0:
+            #     continue
+            h = self.act(self.w1[i](h)) * self.w3[i](h)
             h = self.w2[i](h, reduce=False)
-            # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
+            # logger.info(f"top_x: {top_x}, idx: {idx}")
+            # logger.info(f"routing weights shape: {routing_weights.shape}")
+            h *= routing_weights[top_x, idx, None]
+            out.index_add_(0, top_x, h)
         # Reduce sum
         if self.process_group.size() > 1:
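For reference, a self-contained sketch of the gather/scatter routing pattern that the new loop implements (toy sizes; the expert MLP is replaced by a hypothetical stand-in, and torch.where is used in place of the preallocated nonzero buffer):

```python
import torch

seq_len, hidden_dim, num_experts, top_k = 6, 4, 3, 2
x = torch.randn(seq_len, hidden_dim)
gate_logits = torch.randn(seq_len, num_experts)

# Per-token probabilities over experts, then keep the top-k and renormalize.
all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(all_probs, top_k, dim=1)
routing_weights /= routing_weights.sum(dim=1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)

# expert_mask[e, j, t] == 1 iff token t picked expert e as its j-th choice.
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts
).permute(2, 1, 0)

out = x.new_zeros(seq_len, hidden_dim)
for i in range(num_experts):
    idx, top_x = torch.where(expert_mask[i])   # (top-k slot, token index)
    h = x[top_x]                               # gather tokens routed to expert i
    h = torch.tanh(h)                          # stand-in for act(w1(h)) * w3(h) -> w2
    h = h * routing_weights[top_x, idx, None]  # re-weight by gate probability
    out.index_add_(0, top_x, h)                # scatter back into the output
```

Both formulations compute the same weighted sum, since the old dense path multiplied non-selected experts by a zero weight; the sparse one just skips the tokens whose gate weight for expert i is zero, at the cost of possibly handing an empty batch to the matmul kernel.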