Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not add explicit reshards for CholeskyOp. #289

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

Do not add explicit reshards for CholeskyOp.

GSPMD fully-replicates cholesky-factorization dimensions before applying cholesky function.

As a result, GSPMD partitioner adds extra all-gathers when an explicit reshards added (compared to when not-added) for the case that the operand is not sharded but the result is sharded on the cholesky-factorization dimensions.

Hence, for the moment, turn off explicit reshard inserting for CholeskyOp.

Comparing the output of GSPMD partitioner for:

func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) {
  %0 = stablehlo.cholesky %arg0 {
    sdy.sharding = [{"x"}, {"y"}], 
    sdy.sharding_rule = ([i, j])->([i, j]) 
  : tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
} -> tensor<4x4xf32> [{"x"}, {"y"}]

vs

func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) {
  %0 = sdy.reshard %arg0 [{"x"}, {"y"}]
  %1 = stablehlo.cholesky %0 {
    sdy.sharding = [{"x"}, {"y"}], 
    sdy.sharding_rule = ([i, j])->([i, j]) 
  : tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
} -> tensor<4x4xf32> [{"x"}, {"y"}]

GSPMD partitioner outputs the following diff:

ENTRY %main.5_spmd (param: f32[4,4]) -> f32[2,2] {
  %param = f32[4,4]{1,0} parameter(0), sharding={replicated}
  < %cholesky.0 = f32[4,4]{1,0} 
      cholesky(f32[4,4]{1,0} %param)

  // Prepare for dynamic slicing from [4,4] to [2,2]
  ... 

  < ROOT %dynamic-slice.2 = f32[2,2]{1,0} 
      dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...),
      dynamic_slice_sizes={2,2}}

  > %dynamic-slice.2 = f32[2,2]{1,0} 
      dynamic-slice(f32[4,4]{1,0} %param, s32[] ...),
      dynamic_slice_sizes={2,2}
  > %all-gather = f32[2,4]{1,0} 
      all-gather(f32[2,2]{1,0} %dynamic-slice.2), 
      channel_id=1, 
      replica_groups={{0,1},{2,3}}, 
      dimensions={1}, 
      use_global_device_ids=true
  > %all-gather.1 = f32[4,4]{1,0} 
      all-gather(f32[2,4]{1,0} %all-gather), 
      channel_id=2, 
      replica_groups={{0,2},{1,3}}, 
      dimensions={0}, 
      use_global_device_ids=true
  > %cholesky.0 = f32[4,4]{1,0} 
      cholesky(f32[4,4]{1,0} %all-gather.1)
  > ROOT %dynamic-slice.5 = f32[2,2]{1,0} 
      dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...),
      dynamic_slice_sizes={2,2}

GSPMD fully-replicates cholesky-factorization dimensions before applying cholesky function.

As a result, GSPMD partitioner adds extra all-gathers when an explicit reshards added (compared to when not-added) for the case that the operand is not sharded but the result is sharded on the cholesky-factorization dimensions.

Hence, for the moment, turn off explicit reshard inserting for CholeskyOp.

Comparing the output of GSPMD partitioner for:

```
func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) {
  %0 = stablehlo.cholesky %arg0 {
    sdy.sharding = [{"x"}, {"y"}],
    sdy.sharding_rule = ([i, j])->([i, j])
  : tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
} -> tensor<4x4xf32> [{"x"}, {"y"}]
```

vs

```
func.func @main(%arg0: tensor<4x4xf32> [{}, {}]) {
  %0 = sdy.reshard %arg0 [{"x"}, {"y"}]
  %1 = stablehlo.cholesky %0 {
    sdy.sharding = [{"x"}, {"y"}],
    sdy.sharding_rule = ([i, j])->([i, j])
  : tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
} -> tensor<4x4xf32> [{"x"}, {"y"}]
```

GSPMD partitioner outputs the following diff:
```
ENTRY %main.5_spmd (param: f32[4,4]) -> f32[2,2] {
  %param = f32[4,4]{1,0} parameter(0), sharding={replicated}
  < %cholesky.0 = f32[4,4]{1,0}
      cholesky(f32[4,4]{1,0} %param)

  // Prepare for dynamic slicing from [4,4] to [2,2]
  ...

  < ROOT %dynamic-slice.2 = f32[2,2]{1,0}
      dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...),
      dynamic_slice_sizes={2,2}}

  > %dynamic-slice.2 = f32[2,2]{1,0}
      dynamic-slice(f32[4,4]{1,0} %param, s32[] ...),
      dynamic_slice_sizes={2,2}
  > %all-gather = f32[2,4]{1,0}
      all-gather(f32[2,2]{1,0} %dynamic-slice.2),
      channel_id=1,
      replica_groups={{0,1},{2,3}},
      dimensions={1},
      use_global_device_ids=true
  > %all-gather.1 = f32[4,4]{1,0}
      all-gather(f32[2,4]{1,0} %all-gather),
      channel_id=2,
      replica_groups={{0,2},{1,3}},
      dimensions={0},
      use_global_device_ids=true
  > %cholesky.0 = f32[4,4]{1,0}
      cholesky(f32[4,4]{1,0} %all-gather.1)
  > ROOT %dynamic-slice.5 = f32[2,2]{1,0}
      dynamic-slice(f32[4,4]{1,0} %cholesky.0, ...),
      dynamic_slice_sizes={2,2}
```

PiperOrigin-RevId: 712518126
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant