@@ -873,25 +873,42 @@ def combine_micro_batches(micro_batches):
873873def replace_output_with_custom_grad (combined_output , custom_grad_output ):
874874 """
875875 Replace the main output tensor (logits, last_hidden_state, etc.) in the combined_output
876- with the custom_grad_output, preserving the original structure .
876+ with the custom_grad_output, preserving structure and returning a ModelOutput when possible .
877877 """
878- if hasattr (combined_output , "logits" ):
879- return combined_output .__class__ (
880- ** {** combined_output , "logits" : custom_grad_output }
881- )
882- elif hasattr (combined_output , "last_hidden_state" ):
883- return combined_output .__class__ (
884- ** {** combined_output , "last_hidden_state" : custom_grad_output }
885- )
886- elif isinstance (combined_output , torch .Tensor ):
878+ # If the combined output is already a tensor
879+ if isinstance (combined_output , torch .Tensor ):
887880 return custom_grad_output
888- else :
889- # For custom ModelOutput-like structures, replace the first tensor found
890- for key , value in combined_output .items ():
891- if isinstance (value , torch .Tensor ):
892- combined_output [key ] = custom_grad_output
893- break
894- return combined_output
881+
882+ # Handle ModelOutput subclasses (SequenceClassifierOutput, etc.)
883+ if isinstance (combined_output , ModelOutput ):
884+ data = combined_output .to_dict ()
885+ if "logits" in data :
886+ data ["logits" ] = custom_grad_output
887+ elif "last_hidden_state" in data :
888+ data ["last_hidden_state" ] = custom_grad_output
889+ else :
890+ for k , v in data .items ():
891+ if isinstance (v , torch .Tensor ):
892+ data [k ] = custom_grad_output
893+ break
894+ return combined_output .__class__ (** data )
895+
896+ # Handle dict outputs
897+ if isinstance (combined_output , dict ):
898+ new_output = dict (combined_output )
899+ if "logits" in new_output :
900+ new_output ["logits" ] = custom_grad_output
901+ elif "last_hidden_state" in new_output :
902+ new_output ["last_hidden_state" ] = custom_grad_output
903+ else :
904+ for k , v in new_output .items ():
905+ if isinstance (v , torch .Tensor ):
906+ new_output [k ] = custom_grad_output
907+ break
908+ # Wrap dict in a generic ModelOutput for consistency
909+ return ModelOutput (** new_output )
910+
911+ raise TypeError (f"Unsupported output type: { type (combined_output )} " )
895912
896913
897914def split_into_micro_batches (combined_output , n_micro_batch ):
0 commit comments