From 26944872bc99af2566617f68a291101d802d09b6 Mon Sep 17 00:00:00 2001 From: shardy authors Date: Mon, 6 Jan 2025 07:12:53 -0800 Subject: [PATCH] Do not add explicit reshards for CholeskyOp. GSPMD fully-replicates cholesky-factorization dimensions before applying cholesky function. As a result, GSPMD partitioner adds extra all-gathers when an explicit reshards added (compared to when not-added) for the case that the operand is not sharded but the result is sharded on the cholesky-factorization dimensions. Hence, for the moment, turn off explicit reshard inserting for CholeskyOp. Comparing the output of GSPMD partitioner for: ``` func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) { %0 = stablehlo.cholesky %arg0 { sdy.sharding = [{"x"}, {"y"}], sdy.sharding_rule = ([i, j])->([i, j]) : tensor<4x4xf32> return %0 : tensor<4x4xf32> } -> tensor<4x4xf32> [{"x"}, {"y"}] ``` vs ``` func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) { %0 = sdy.reshard %arg0 [{"x"}, {"y"}] %1 = stablehlo.cholesky %0 { sdy.sharding = [{"x"}, {"y"}], sdy.sharding_rule = ([i, j])->([i, j]) : tensor<4x4xf32> return %1 : tensor<4x4xf32> } -> tensor<4x4xf32> [{"x"}, {"y"}] ``` GSPMD partitioner outputs the following diff: ``` ENTRY %main.5_spmd (param: f32[4,4]) -> f32[2,2] { %param = f32[4,4]{1,0} parameter(0), sharding={replicated} < %cholesky.0 = f32[4,4]{1,0} cholesky(f32[4,4]{1,0} %param) // Prepare for dynamic slicing from [4,4] to [2,2] ... < ROOT %dynamic-slice.2 = f32[2,2]{1,0} dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...), dynamic_slice_sizes={2,2}} > %dynamic-slice.2 = f32[2,2]{1,0} dynamic-slice(f32[4,4]{1,0} %param, s32[] ...), dynamic_slice_sizes={2,2} > %all-gather = f32[2,4]{1,0} all-gather(f32[2,2]{1,0} %dynamic-slice.2), channel_id=1, replica_groups={{0,1},{2,3}}, dimensions={1}, use_global_device_ids=true > %all-gather.1 = f32[4,4]{1,0} all-gather(f32[2,4]{1,0} %all-gather), channel_id=2, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true > %cholesky.0 = f32[4,4]{1,0} cholesky(f32[4,4]{1,0} %all-gather.1) > ROOT %dynamic-slice.5 = f32[2,2]{1,0} dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...), dynamic_slice_sizes={2,2} ``` PiperOrigin-RevId: 712518126 --- .../sdy/transforms/export/insert_explicit_reshards.cc | 8 ++++++++ .../transforms/export/test/insert_explicit_reshards.mlir | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc index 04c5ba45..19ec5e59 100644 --- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc @@ -532,6 +532,14 @@ struct InsertExplicitReshardsPass return; } + // GSPMD partitioner adds extra all-gathers when an explicit + // reshards added for the case that the operand is not sharded but the + // result is sharded on the cholesky-factorization dimensions. + // TODO(enver): Handle CholeskyOp. + if (isa(op)) { + return; + } + // Checks if factors are sharded the same way across operands and results. if (hasCompatibleFactorShardings(shardingProjection)) { return; diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir index eab0c60b..4668a986 100644 --- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir @@ -450,3 +450,10 @@ func.func @dot_genaral_overlaps_and_trimmable_on_subaxis_multiple_axes(%arg0: te return %0 : tensor<64x8x16xf32> } +// CHECK-LABEL: func @cholesky +func.func @cholesky(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NOT: sdy.reshard + %0 = stablehlo.cholesky %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=4, j=4}>} : tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} +