Skip to content

Commit

Permalink
Add fixed moe from mixtral
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Huang <[email protected]>
  • Loading branch information
pi314ever committed Jan 27, 2025
1 parent a856d3d commit b693487
Showing 1 changed file with 19 additions and 34 deletions.
53 changes: 19 additions & 34 deletions optimum/habana/transformers/models/snowflake/modeling_arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,48 +754,33 @@ def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor:

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.top_k > 1:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
(batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

# Matching between experts, tokens, and their top-k rank. For every i,
# expert_idx[i] is the rank topk_idx[i] expert for token_idx[i].
expert_idx, token_idx, topk_idx = torch.where(
selected_experts
== torch.arange(
self.num_experts,
device=selected_experts.device,
).view((self.num_experts, 1, 1))
padded_weights = torch.zeros(
(batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
)

# Split into one chunk per expert.
bincount = torch.bincount(expert_idx, minlength=self.num_experts).tolist()
token_idx = token_idx.split(bincount)
topk_idx = topk_idx.split(bincount)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

# Loop over all available experts in the model and perform the computation on each expert
for expert_layer, top_x, idx in zip(self.experts, token_idx, topk_idx):
if top_x.shape[0] == 0:
continue

# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
# torch.distributed.barrier()
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, hidden_dim)
current_hidden_states_static = (
expert_layer(current_state_static).reshape(-1, sequence_length, hidden_dim) * padded_weight
)
final_hidden_states += current_hidden_states_static
# support long sequences exceeding 8192
if not self.training and sequence_length > 8192:
htcore.mark_step()
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, load_balancing_loss_func(
(router_logits,), self.num_experts, self.top_k
Expand Down

0 comments on commit b693487

Please sign in to comment.