
Fix CrossEntropyLoss block to support multi-output models#28232

Open
Rishi-Dave wants to merge 1 commit into microsoft:main from Rishi-Dave:rishidave/fix/onnxblock-ce-loss-multi-output

Conversation

@Rishi-Dave
Contributor

Summary

  • artifacts.generate_artifacts(..., loss=LossType.CrossEntropyLoss) no longer aborts with i < node_->OutputDefs().size() when the base model has multi-dimensional outputs.
  • The SoftmaxCrossEntropyLoss op produces two outputs (loss, log_prob); the second was being dropped by graph optimizers because it had no value_info entry, leaving the gradient builder to dereference a missing output def via O(1).
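The failure mode above can be illustrated with a small, dependency-free sketch (class and variable names are hypothetical stand-ins; the real assertion fires in the C++ gradient builder, not in Python):

```python
# Conceptual sketch: the SoftmaxCrossEntropyLoss node declares two
# outputs, but graph cleanup prunes the second (log_prob) when it has
# no value_info entry, so a later index-1 lookup is out of range.

class Node:
    """Toy stand-in for an ONNX node (hypothetical, for illustration)."""
    def __init__(self, op_type, outputs):
        self.op_type = op_type
        self.outputs = list(outputs)

sce = Node("SoftmaxCrossEntropyLoss", ["loss", "log_prob"])

# Graph optimizers drop output defs that carry no value_info entry:
value_info_names = {"loss"}  # log_prob was never registered
sce.outputs = [o for o in sce.outputs if o in value_info_names]

# The gradient builder then asks for the second output, which in the
# C++ engine trips the i < node_->OutputDefs().size() assertion
# (modeled here as an IndexError):
try:
    sce.outputs[1]
except IndexError:
    print("gradient builder would assert: i < node_->OutputDefs().size()")
```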

Motivation

Fixes #22465. Users hit a hard C++ assertion when training models like DistilBERT whose forward graph emits a multi-dimensional last-hidden-state tensor. The same pattern appears for any seq2seq / LM training setup that pipes a 3-D output into CrossEntropyLoss.

This is a Python-only change scoped to the onnxblock training-artifacts API; the core inference engine is unaffected.

Changes

  • orttraining/orttraining/python/training/onnxblock/loss/loss.py — after appending the SoftmaxCrossEntropyLoss node, register a value_info entry for log_prob_output_name so its output def survives shape inference and graph cleanup. Idempotent — guarded against duplicate entries.
  • orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py — new test_crossentropy_loss_multi_output_model builds a 3-D output toy model, calls generate_artifacts with LossType.CrossEntropyLoss, and asserts the saved training_model.onnx retains both outputs on the SCE node.
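The loss.py change follows this shape (a simplified, dependency-free sketch; the actual code operates on the graph protobuf via the onnx helper API, and the helper name below is hypothetical):

```python
# Sketch of the idempotent value_info registration added after the
# SoftmaxCrossEntropyLoss node is appended in CrossEntropyLoss.build().

class ValueInfo:
    """Toy stand-in for an ONNX ValueInfoProto (illustrative only)."""
    def __init__(self, name, elem_type):
        self.name = name
        self.elem_type = elem_type

def register_log_prob_value_info(value_info, log_prob_output_name, scores_elem_type):
    """Register log_prob so graph cleanup keeps the second SCE output.

    Reuses the elem_type of the input scores tensor, and is guarded so
    that calling build() more than once does not add duplicate entries.
    """
    if any(vi.name == log_prob_output_name for vi in value_info):
        return  # already registered; keeps build() idempotent
    value_info.append(ValueInfo(log_prob_output_name, scores_elem_type))

graph_value_info = []
register_log_prob_value_info(graph_value_info, "log_prob", scores_elem_type=1)  # 1 == FLOAT
register_log_prob_value_info(graph_value_info, "log_prob", scores_elem_type=1)  # no-op
print(len(graph_value_info))  # 1
```

Because only `value_info` is touched, the block's public surface is unchanged: `build()` still returns only `loss_node_output_name`.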

Test Plan

  • New test exercises the previously-failing path:
    python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py::test_crossentropy_loss_multi_output_model -v
  • Existing CE coverage:
    python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py -k crossentropy -v
  • lintrunner clean on the diff.

Fixes #22465

…put models

CrossEntropyLoss.build() created a SoftmaxCrossEntropyLoss node with two
outputs (loss, log_prob) but never registered log_prob in model.graph.value_info.
Graph optimizers then dropped the output def, causing the gradient builder to
hit a C++ assertion (i < node_->OutputDefs().size()) via O(1) when generating
training artifacts for models with multi-dimensional outputs (e.g. seq2seq).

Fix: after appending the node, add a value_info entry for log_prob_output_name
using the same elem_type as the input scores tensor. A guard prevents duplicate
entries if build() is called more than once. This keeps the output def alive
through graph cleanup without changing the user-visible API (the block still
returns only loss_node_output_name).

Fixes microsoft#22465


Closes #22465: [Training] Error building gradient graph for bert models for on-device training