Conversation

mdazz commented Sep 5, 2025

This commit teaches AtenToDtypeOp::fold to constant-fold dtype conversions when the operand is a splat DenseElementsAttr.

Folding follows torch's rounding behavior:

  • Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
  • Float → Int: round toward zero.
  • Int → Float: sign-aware, rmNearestTiesToEven.
  • Float ↔ Float: use builtin mlir::FloatType::getFloatSemantics().
  • Int ↔ Int: use zextOrTrunc / sextOrTrunc based on source signedness.

Folding is only performed when non_blocking == false, copy == false, and memory_format is None.
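
Roughly, the per-element logic looks like this (a minimal sketch with illustrative helper names, not the exact code in this PR), built on llvm::APFloat / llvm::APInt:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"

using namespace llvm;

// Float -> Bool: 0.0 and -0.0 map to false; nonzero, NaN, and ±Inf map to true.
static APInt floatToBool(const APFloat &v) {
  return APInt(/*numBits=*/1, v.isZero() ? 0 : 1);
}

// Float -> Int: round toward zero, matching torch's truncating cast.
static APInt floatToInt(const APFloat &v, unsigned dstWidth, bool dstIsSigned) {
  APSInt result(dstWidth, /*isUnsigned=*/!dstIsSigned);
  bool isExact;
  v.convertToInteger(result, APFloat::rmTowardZero, &isExact);
  return result;
}

// Int -> Float: sign-aware, round to nearest, ties to even.
static APFloat intToFloat(const APInt &v, const fltSemantics &dstSem,
                          bool srcIsSigned) {
  APFloat result(dstSem);
  result.convertFromAPInt(v, srcIsSigned, APFloat::rmNearestTiesToEven);
  return result;
}

// Float -> Float: reconvert under the destination type's semantics
// (obtained from mlir::FloatType::getFloatSemantics()).
static APFloat floatToFloat(APFloat v, const fltSemantics &dstSem) {
  bool losesInfo;
  v.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
  return v;
}

// Int -> Int: extend or truncate according to the *source* signedness.
static APInt intToInt(const APInt &v, unsigned dstWidth, bool srcIsSigned) {
  return srcIsSigned ? v.sextOrTrunc(dstWidth) : v.zextOrTrunc(dstWidth);
}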

mdazz (Author) commented Sep 5, 2025

Not sure who can review; maybe you would know, @vivekkhandelwal1 @zjgarvey?

// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Member commented:

Can you put the actual value here? I think it will be 42.0.
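
For example (a sketch: MLIR prints f32 splats in scientific notation, so 42.0 should render as 4.200000e+01):

// CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>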


// int64 splat (max int32) → int32 (trunc)
%int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
Member commented:

Can this value be int32max + 1 to ensure that truncation actually happens in the IR being locked down?
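
A sketch of the suggested tweak, assuming torch's wrap-around truncation semantics: 2147483648 (int32max + 1) truncates to -2147483648 in si32:

%int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
// CHECK: torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>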

// float32 splat → float64
%float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
%int7 = torch.constant.int 7 // torch.float64
// CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
Member commented:

Let's capture the actual value here too, and in other such places.
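
One way to sketch that without depending on MLIR's full decimal expansion of 2.71828 in f32: use a value that is exact in f32, e.g. 2.75, so the widened f64 splat prints losslessly (the exact printed form depends on MLIR's float printer):

%float32_splat = torch.vtensor.literal(dense<2.75> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
// CHECK: torch.vtensor.literal(dense<2.750000e+00> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>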

// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Member commented:

Since we are not locking down the values returned in the output IR, I think we should also add CHECK-NOT: torch.aten.to.dtype to ensure that the op is actually folded.
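
For example (sketch), immediately after matching the folded literal:

// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
// CHECK-NOT: torch.aten.to.dtype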

@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch.tensor
return %0 : !torch.tensor
}

// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
Member commented:

Can you add some e2e tests to ensure that torch's rounding logic is accurately captured in this implementation?
Also, please fix the CI failures; we cannot merge until the CI pipelines are green.
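
For reference, an e2e case could look roughly like this (a sketch following the existing torch-mlir e2e test-suite pattern; the module name is a placeholder). A negative fractional splat pins down the round-toward-zero behavior (-2.7 -> -2):

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class ToDtypeConstFoldModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # Splat literal; the negative fraction exercises round-toward-zero.
        return torch.full((2, 3), -2.7).to(torch.int32)

@register_test_case(module_factory=lambda: ToDtypeConstFoldModule())
def ToDtypeConstFoldModule_basic(module, tu: TestUtils):
    module.forward()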
