Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions orttraining/orttraining/python/training/onnxblock/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
)
self.base.graph.node.append(loss_node)

# Register log_prob in value_info so the gradient builder can resolve
# O(1) (the second output of SoftmaxCrossEntropyLoss). Without this,
# graph optimizers may drop the output def and cause a C++ assertion.
scores_info = _graph_utils.get_output_from_output_name(self.base, scores_input_name)
scores_elem_type = scores_info.type.tensor_type.elem_type
if not any(vi.name == log_prob_output_name for vi in self.base.graph.value_info):
self.base.graph.value_info.append(
onnx.helper.make_tensor_value_info(log_prob_output_name, scores_elem_type, None)
)

return loss_node_output_name


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1187,3 +1187,60 @@ def test_generate_artifacts_external_data_separate_files(loss):
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))


def test_crossentropy_loss_multi_output_model():
    """Regression test for https://github.com/microsoft/onnxruntime/issues/22465.

    CrossEntropyLoss.build() used to emit a SoftmaxCrossEntropyLoss node with
    two outputs (loss, log_prob) while registering only the loss output, so
    generate_artifacts on a base model with a multi-dimensional output such as
    (batch, seq_len, vocab_size) crashed with the C++ assertion
    (i < node_->OutputDefs().size()) when the gradient builder dereferenced the
    missing log_prob output def. Artifact generation must therefore complete
    without raising, and both output defs must survive in the saved model.
    """

    class MultiOutputNet(torch.nn.Module):
        """Tiny seq2seq-style model producing a 3-D (batch, seq, vocab) output."""

        def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            self.linear = torch.nn.Linear(hidden_size, vocab_size)

        def forward(self, x):
            # (batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits
            return self.linear(self.embed(x))

    n_vocab, n_batch, n_seq = 16, 2, 4

    net = MultiOutputNet(vocab_size=n_vocab)
    net.eval()

    token_ids = torch.randint(0, n_vocab, (n_batch, n_seq))
    onnx_model = _get_onnx_model(net, (token_ids,))

    trainable = ["embed.weight", "linear.weight", "linear.bias"]

    with tempfile.TemporaryDirectory() as artifact_dir:
        # Must not trip "i < node_->OutputDefs().size()" or fail in any other way.
        artifacts.generate_artifacts(
            onnx_model,
            requires_grad=trainable,
            loss=artifacts.LossType.CrossEntropyLoss,
            artifact_directory=artifact_dir,
        )

        # The training model must have been written out.
        trained_path = os.path.join(artifact_dir, "training_model.onnx")
        assert os.path.exists(trained_path)

        # Both output defs (loss and log_prob) must survive on the
        # SoftmaxCrossEntropyLoss node in the saved training model.
        saved = onnx.load(trained_path)
        sce = [node for node in saved.graph.node if node.op_type == "SoftmaxCrossEntropyLoss"]
        assert len(sce) == 1, "Expected exactly one SoftmaxCrossEntropyLoss node"
        assert len(sce[0].output) == 2, (
            f"SoftmaxCrossEntropyLoss node must have 2 outputs (loss, log_prob), got {list(sce[0].output)}"
        )