Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix if-clause #1325

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,9 +1248,8 @@ def forward(self, x, attention_mask, layer_past=None):
raise KeyError(self.moe_type)

with torch.enable_grad() if not self.eval else nullcontext():
if (
mlp_bias == None,
self.num_experts > 1 and self.moe_type == "deepspeed",
if mlp_bias == None or (
self.num_experts > 1 and self.moe_type == "deepspeed"
):
# No dropout either
assert mlp_bias is None
Expand Down
2 changes: 1 addition & 1 deletion megatron/neox_arguments/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def calculate_derived(self):
else:
fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts"
assert self.precision == "fp16", fp16_conflict

if self.bf16 and self.bf16.get("enabled", False):
if self.precision is None:
self.update_value("precision", "bfloat16")
Expand Down
Loading