Some MoE exploration
commit f6ad3b3585 (parent dbb23fbfa8)
@@ -1465,7 +1465,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
     max_m_blocks--;
     if (max_m_blocks == 0) {
-      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
+      //TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
     }
   }
 
 
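Note: this TORCH_CHECK sits at the bottom of Marlin's retry loop, which keeps shrinking max_m_blocks until a valid thread configuration fits. Commenting it out lets the loop fall through instead of aborting, which appears to be groundwork for the degenerate problem shapes (down to prob_m == 0) that MoE dispatch can produce; the next hunk relaxes the corresponding input check.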
@@ -1583,7 +1583,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
                     int thread_n, int sms, int max_par) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
-  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
+  TORCH_CHECK(prob_m >= 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
               ", ", prob_n, ", ", prob_k, "]");
 
   int tot_m = prob_m;
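The prob_m > 0 → prob_m >= 0 relaxation matters for MoE: with top-k routing an expert can end up with zero assigned tokens, so the per-expert GEMM is invoked with M == 0. A minimal sketch of how that empty batch arises (shapes are illustrative, not from the source):

import torch

x = torch.randn(16, 4096)                  # (sequence_length, hidden_dim)
top_x = torch.empty(0, dtype=torch.long)   # no token routed to this expert

h = x[top_x]                               # (0, 4096): an empty batch
assert h.shape[0] == 0
# Downstream, a Marlin-style entry point sees prob_m == 0; the original
# `prob_m > 0` check would raise here instead of allowing a no-op GEMM.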
@@ -442,29 +442,51 @@ class DenseMoE(nn.Module):
         gate_logits = self.gate(x)
         # all_probs: (sequence_length, n_experts) and upcast for softmax
         all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+        # (sequence_length, n_experts_per_tok)
+        routing_weights, selected_experts = torch.topk(all_probs, self.top_k, dim=1)
+        routing_weights /= routing_weights.sum(dim=1, keepdim=True)
+        routing_weights.to(x.dtype)
 
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
+        # logger.info(
+        #     f"routing_weights: {routing_weights.shape}, selected_experts: {selected_experts.shape}"
+        # )
+
+        # logger.info(
+        #     f"expert mask before permute: {torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape}",
+        # )
+
+        # expert_mask: (n_experts, sequence_length, n_experts_per_tok)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+        # logger.info(f"expert mask shape: {expert_mask.shape}")
+        # logger.info(f"expert mask: {expert_mask}")
+
+        indices = torch.empty(
+            (x.shape[0] * self.top_k, 2), dtype=torch.long, device=x.device
+        )
+        # logger.info(f"indices shape: {indices.shape}")
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            # idx, top_x = torch.where(expert_mask[i])
+            torch.nonzero(expert_mask[i], out=indices)
+            idx = indices.t()[0]
+            top_x = indices.t()[1]
+            h = x[None, top_x].reshape(-1, self.hidden_dim)
+
+            # Sometimes an expert is not used for any tokens. However, some matmul
+            # kernels do not support empty batches (e.g. Marlin).
+            # if h.shape[0] == 0:
+            #     continue
+
+            h = self.act(self.w1[i](h)) * self.w3[i](h)
             h = self.w2[i](h, reduce=False)
+            # logger.info(f"top_x: {top_x}, idx: {idx}")
+            # logger.info(f"routing weights shape: {routing_weights.shape}")
+            h *= routing_weights[top_x, idx, None]
             # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
+            out.index_add_(0, top_x, h)
 
         # Reduce sum
         if self.process_group.size() > 1:
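For reference, a minimal self-contained version of the dispatch pattern the new code adopts (Mixtral-style routing: one-hot expert mask, per-expert gather, weighted index_add_). Plain nn.Linear experts stand in for the sharded w1/w3/w2 layers and reduce=False is dropped, so this is a sketch rather than the server implementation. It also firms up two details the exploratory diff leaves open: .to(x.dtype) is not in-place and must be reassigned, and an expert with no tokens is skipped. Note too that after permute(2, 1, 0) the mask is (n_experts, n_experts_per_tok, sequence_length), not the shape given in the diff comment.

import torch
import torch.nn as nn


class SparseMoESketch(nn.Module):
    def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.w1 = nn.ModuleList(nn.Linear(hidden_dim, ffn_dim, bias=False) for _ in range(num_experts))
        self.w3 = nn.ModuleList(nn.Linear(hidden_dim, ffn_dim, bias=False) for _ in range(num_experts))
        self.w2 = nn.ModuleList(nn.Linear(ffn_dim, hidden_dim, bias=False) for _ in range(num_experts))
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (sequence_length, hidden_dim)
        gate_logits = self.gate(x)
        all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)

        # Keep only the top-k experts per token and renormalize their weights.
        routing_weights, selected_experts = torch.topk(all_probs, self.top_k, dim=1)
        routing_weights /= routing_weights.sum(dim=1, keepdim=True)
        # .to() is not in-place; the result must be reassigned.
        routing_weights = routing_weights.to(x.dtype)

        # one_hot: (seq, top_k, n_experts); after permute(2, 1, 0) the mask is
        # (n_experts, top_k, seq), so expert_mask[i] is (top_k, seq).
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        out = x.new_zeros(x.shape[0], self.hidden_dim)
        for i in range(self.num_experts):
            # idx: which top-k slot chose expert i; top_x: which token.
            idx, top_x = torch.where(expert_mask[i])
            if top_x.numel() == 0:
                continue  # expert i received no tokens; avoid an empty GEMM

            h = x[top_x]
            h = self.act(self.w1[i](h)) * self.w3[i](h)
            h = self.w2[i](h)
            h *= routing_weights[top_x, idx, None]
            # Scatter-add each expert's rows back to their token positions.
            out.index_add_(0, top_x, h)
        return out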