[Torch] Fold aten.to.dtype on splat constants. #4306
base: main
Conversation
Not sure who can review; maybe you would know, @vivekkhandelwal1 @zjgarvey?
This commit teaches `AtenToDtypeOp::fold` to constant-fold dtype conversions when the operand is a splat `DenseElementsAttr`. Folding follows torch's rounding behavior:

* Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
* Float → Int: round toward zero.
* Int → Float: sign-aware, rmNearestTiesToEven.
* Float ↔ Float: use builtin `mlir::FloatType::getFloatSemantics()`.
* Int ↔ Int: use `zextOrTrunc` / `sextOrTrunc` based on source signedness.

Folding is only performed when `non_blocking == false`, `copy == false`, and `memory_format` is None.
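As a rough stdlib-Python sketch of two of these rules (helper names are illustrative, not part of the patch; the patch itself works on `APFloat`/`APInt`):

```python
import math
import struct

def fold_to_bool(x: float) -> bool:
    # 0 and -0.0 fold to False; any nonzero value, NaN, or +/-Inf folds to True.
    return x != 0.0 or math.isnan(x)

def fold_int_to_f32(n: int) -> float:
    # Nearest-ties-to-even rounding into float32, approximated here by a
    # struct round-trip; the patch uses APFloat with rmNearestTiesToEven.
    return struct.unpack("<f", struct.pack("<f", float(n)))[0]
```

For example, 2^24 + 1 is not representable in float32, so it rounds to the even neighbor 2^24.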
Force-pushed from 3bf4e4b to 1d7b55b (Compare)
// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Can you put the actual value, which I think here will be 42.0?
-> !torch.vtensor<[4,4],si32>
// int64 splat (max int32) → int32 (trunc)
%int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
Can this value be int32max + 1, to ensure that truncation actually happens in the IR being locked down?
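For reference, a stdlib sketch of what truncation to i32 (as `sextOrTrunc` performs) does to these two values; the wrap to int32min only appears for int32max + 1, so that input pins down the truncation:

```python
def trunc_to_i32(n: int) -> int:
    # Keep the low 32 bits, then reinterpret as signed two's complement,
    # mirroring a truncating cast to i32.
    n &= 0xFFFFFFFF
    return n - (1 << 32) if n >= (1 << 31) else n

# int32max folds to itself, so a test on it would pass even if no truncation
# occurred; int32max + 1 wraps to int32min, which only truncation produces.
```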
// float32 splat → float64
%float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
%int7 = torch.constant.int 7 // torch.float64
// CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
Let's capture the actual value here too, and in other such places.
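One way to compute the exact value to put in the CHECK line (a stdlib sketch; 2.71828 is the f32 literal from the test above):

```python
import struct

# Round the decimal literal 2.71828 to float32: this is the value the f32
# splat actually holds before conversion.
f32_value = struct.unpack("<f", struct.pack("<f", 2.71828))[0]

# Widening f32 -> f64 is exact, so the folded f64 splat carries this same
# value, not the original f64 interpretation of the literal 2.71828.
assert f32_value != 2.71828
```

Printing `f32_value` gives the precise decimal to use in the `CHECK` line for the f64 result.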
// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Since we are not locking down the values returned in the output IR, I think we should also add `CHECK-NOT: torch.aten.to.dtype` to ensure that the op is actually being folded.
@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
  return %0 : !torch.tensor
}

// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
Can you add some e2e tests to ensure that torch's rounding logic is accurately captured by this implementation?
Also, please fix the CI failures; we cannot merge until the CI pipelines are green.
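A few float → int cases such e2e tests could pin down (`math.trunc` models torch's round-toward-zero cast; the expected values follow from that rule):

```python
import math

# Round-toward-zero cases worth locking down in an e2e test: negatives,
# exact halves, and values just below an integer all truncate toward zero.
cases = [(-2.7, -2), (-0.5, 0), (0.5, 0), (2.999, 2)]
for x, expected in cases:
    assert math.trunc(x) == expected
```

Negative inputs are the ones most likely to expose a floor-vs-truncate mismatch, so they are worth including explicitly.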