Some MoE exploration
This commit is contained in:
parent
dbb23fbfa8
commit
f6ad3b3585
|
@ -1465,7 +1465,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
|||
|
||||
max_m_blocks--;
|
||||
if (max_m_blocks == 0) {
|
||||
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
|
||||
//TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1583,7 +1583,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
|||
int thread_n, int sms, int max_par) {
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
TORCH_CHECK(prob_m >= 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int tot_m = prob_m;
|
||||
|
|
|
@ -442,29 +442,51 @@ class DenseMoE(nn.Module):
|
|||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
# (sequence_length, n_experts_per_tok)
|
||||
routing_weights, selected_experts = torch.topk(all_probs, self.top_k, dim=1)
|
||||
routing_weights /= routing_weights.sum(dim=1, keepdim=True)
|
||||
routing_weights.to(x.dtype)
|
||||
|
||||
if self.top_k < self.num_experts:
|
||||
_, not_selected_experts = torch.topk(
|
||||
all_probs,
|
||||
self.num_experts - self.top_k,
|
||||
largest=False,
|
||||
sorted=False,
|
||||
dim=1,
|
||||
)
|
||||
# Mask not selected experts
|
||||
all_probs.scatter_(1, not_selected_experts, 0)
|
||||
# logger.info(
|
||||
# f"routing_weights: {routing_weights.shape}, selected_experts: {selected_experts.shape}"
|
||||
# )
|
||||
|
||||
# Re-normalize
|
||||
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
|
||||
weights = weights.to(x.dtype)
|
||||
# logger.info(
|
||||
# f"expert mask before permute: {torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape}",
|
||||
# )
|
||||
|
||||
# expert_mask: (n_experts, sequence_length, n_experts_per_tok)
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.num_experts
|
||||
).permute(2, 1, 0)
|
||||
# logger.info(f"expert mask shape: {expert_mask.shape}")
|
||||
# logger.info(f"expert mask: {expert_mask}")
|
||||
|
||||
indices = torch.empty(
|
||||
(x.shape[0] * self.top_k, 2), dtype=torch.long, device=x.device
|
||||
)
|
||||
# logger.info(f"indices shape: {indices.shape}")
|
||||
# Final output tensor
|
||||
out = x.new_zeros(x.shape[0], self.hidden_dim)
|
||||
for i in range(self.num_experts):
|
||||
h = self.act(self.w1[i](x)) * self.w3[i](x)
|
||||
# idx, top_x = torch.where(expert_mask[i])
|
||||
torch.nonzero(expert_mask[i], out=indices)
|
||||
idx = indices.t()[0]
|
||||
top_x = indices.t()[1]
|
||||
h = x[None, top_x].reshape(-1, self.hidden_dim)
|
||||
|
||||
# Sometimes an expert is not used for any tokens. However, some matmul
|
||||
# kernels do not support empty batches (e.g. Marlin).
|
||||
# if h.shape[0] == 0:
|
||||
# continue
|
||||
|
||||
h = self.act(self.w1[i](h)) * self.w3[i](h)
|
||||
h = self.w2[i](h, reduce=False)
|
||||
# logger.info(f"top_x: {top_x}, idx: {idx}")
|
||||
# logger.info(f"routing weights shape: {routing_weights.shape}")
|
||||
h *= routing_weights[top_x, idx, None]
|
||||
# Add expert output to out with masking
|
||||
out += h * weights[:, i].view(-1, 1)
|
||||
out.index_add_(0, top_x, h)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
|
|
Loading…
Reference in New Issue