diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc index 04c5ba45..19ec5e59 100644 --- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc @@ -532,6 +532,14 @@ struct InsertExplicitReshardsPass return; } + // GSPMD partitioner adds extra all-gathers when an explicit + // reshards added for the case that the operand is not sharded but the + // result is sharded on the cholesky-factorization dimensions. + // TODO(enver): Handle CholeskyOp. + if (isa(op)) { + return; + } + // Checks if factors are sharded the same way across operands and results. if (hasCompatibleFactorShardings(shardingProjection)) { return; diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir index eab0c60b..4668a986 100644 --- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir @@ -450,3 +450,10 @@ func.func @dot_genaral_overlaps_and_trimmable_on_subaxis_multiple_axes(%arg0: te return %0 : tensor<64x8x16xf32> } +// CHECK-LABEL: func @cholesky +func.func @cholesky(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NOT: sdy.reshard + %0 = stablehlo.cholesky %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=4, j=4}>} : tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} +