Skip to content

Commit 72b4123

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Fix misuses of jax._src.lib.mlir.
Add mpmd dialect to jax.extend.mlir.dialects. PiperOrigin-RevId: 826263298
1 parent 00afdd2 commit 72b4123

11 files changed

+65
-27
lines changed

jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ pytype_strict_library(
3232
pypi_requirement("jax"),
3333
pypi_requirement("jax/_src/lib"),
3434
pypi_requirement("jax/extend"),
35+
pypi_requirement("jax/extend/mlir"),
36+
pypi_requirement("jax/extend/mlir:ir"),
37+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
3538
pypi_requirement("numpy"),
3639
],
3740
)
@@ -43,6 +46,9 @@ pytype_strict_library(
4346
pypi_requirement("jax"),
4447
pypi_requirement("jax/_src/lib"),
4548
pypi_requirement("jax/extend"),
49+
pypi_requirement("jax/extend/mlir"),
50+
pypi_requirement("jax/extend/mlir:ir"),
51+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
4652
pypi_requirement("numpy"),
4753
],
4854
)
@@ -58,6 +64,10 @@ pytype_strict_library(
5864
pypi_requirement("jax"),
5965
pypi_requirement("jax/_src/lib"),
6066
pypi_requirement("jax/extend"),
67+
pypi_requirement("jax/extend/mlir"),
68+
pypi_requirement("jax/extend/mlir:ir"),
69+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
70+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
6171
pypi_requirement("numpy"),
6272
],
6373
)
@@ -73,6 +83,10 @@ pytype_strict_library(
7383
pypi_requirement("jax"),
7484
pypi_requirement("jax/_src/lib"),
7585
pypi_requirement("jax/extend"),
86+
pypi_requirement("jax/extend/mlir"),
87+
pypi_requirement("jax/extend/mlir:ir"),
88+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
89+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
7690
pypi_requirement("numpy"),
7791
],
7892
)
@@ -88,6 +102,10 @@ pytype_strict_library(
88102
pypi_requirement("jax"),
89103
pypi_requirement("jax/_src/lib"),
90104
pypi_requirement("jax/extend"),
105+
pypi_requirement("jax/extend/mlir"),
106+
pypi_requirement("jax/extend/mlir:ir"),
107+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
108+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
91109
pypi_requirement("numpy"),
92110
],
93111
)
@@ -101,6 +119,10 @@ pytype_strict_library(
101119
pypi_requirement("jax"),
102120
pypi_requirement("jax/_src/lib"),
103121
pypi_requirement("jax/extend"),
122+
pypi_requirement("jax/extend/mlir"),
123+
pypi_requirement("jax/extend/mlir:ir"),
124+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
125+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
104126
pypi_requirement("numpy"),
105127
],
106128
)
@@ -116,6 +138,10 @@ pytype_strict_library(
116138
pypi_requirement("jax"),
117139
pypi_requirement("jax/_src/lib"),
118140
pypi_requirement("jax/extend"),
141+
pypi_requirement("jax/extend/mlir"),
142+
pypi_requirement("jax/extend/mlir:ir"),
143+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
144+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
119145
pypi_requirement("numpy"),
120146
],
121147
)
@@ -131,6 +157,10 @@ pytype_strict_library(
131157
pypi_requirement("jax"),
132158
pypi_requirement("jax/_src/lib"),
133159
pypi_requirement("jax/extend"),
160+
pypi_requirement("jax/extend/mlir"),
161+
pypi_requirement("jax/extend/mlir:ir"),
162+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
163+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
134164
pypi_requirement("numpy"),
135165
],
136166
)
@@ -143,6 +173,11 @@ pytype_strict_library(
143173
deps = [
144174
pypi_requirement("jax"),
145175
pypi_requirement("jax/_src/lib"),
176+
pypi_requirement("jax/extend"),
177+
pypi_requirement("jax/extend/mlir"),
178+
pypi_requirement("jax/extend/mlir:ir"),
179+
pypi_requirement("jax/extend/mlir/dialects:func_dialect"),
180+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
146181
],
147182
)
148183

@@ -156,6 +191,9 @@ pytype_strict_library(
156191
pypi_requirement("jax"),
157192
pypi_requirement("jax/_src/lib"),
158193
pypi_requirement("jax/extend"),
194+
pypi_requirement("jax/extend/mlir"),
195+
pypi_requirement("jax/extend/mlir:ir"),
196+
pypi_requirement("jax/extend/mlir/dialects:stablehlo_dialect"),
159197
pypi_requirement("numpy"),
160198
],
161199
)

jax_tpu_embedding/sparsecore/lib/core/primitives/local_sparse_dense_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import jax
2020
from jax import core
21-
from jax._src.lib.mlir import ir
22-
from jax._src.lib.mlir.dialects import hlo
2321
import jax.extend as jex
22+
from jax.extend.mlir import ir
23+
from jax.extend.mlir.dialects import stablehlo as hlo
2424
from jax.interpreters import mlir
2525
from jax.interpreters import xla
2626
import jax.numpy as jnp

jax_tpu_embedding/sparsecore/lib/core/primitives/optimizers_computation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# limitations under the License.
1414
"""Defines common optimizers for embedding lookups."""
1515

16-
from jax._src.lib.mlir import ir
17-
from jax._src.lib.mlir.dialects import func as func_dialect
18-
from jax._src.lib.mlir.dialects import hlo
16+
from jax.extend.mlir import ir
17+
from jax.extend.mlir.dialects import func as func_dialect
18+
from jax.extend.mlir.dialects import stablehlo as hlo
1919
from jax.interpreters import mlir
2020

2121

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import jax
2020
from jax import core
21-
from jax._src.lib.mlir import ir
22-
from jax._src.lib.mlir.dialects import hlo
2321
import jax.extend as jex
22+
from jax.extend.mlir import ir
23+
from jax.extend.mlir.dialects import stablehlo as hlo
2424
from jax.interpreters import mlir
2525
from jax.interpreters import xla
2626
import jax.numpy as jnp

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_adagrad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from typing import Tuple
2626

2727
import jax
28-
from jax._src.lib.mlir import ir
29-
from jax._src.lib.mlir.dialects import func as func_dialect
30-
from jax._src.lib.mlir.dialects import hlo
3128
import jax.extend as jex
29+
from jax.extend.mlir import ir
30+
from jax.extend.mlir.dialects import func as func_dialect
31+
from jax.extend.mlir.dialects import stablehlo as hlo
3232
from jax.interpreters import mlir
3333
from jax.interpreters import xla
3434
from jax_tpu_embedding.sparsecore.lib.core import constants

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_adagrad_momentum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
from typing import Tuple
2828

2929
import jax
30-
from jax._src.lib.mlir import ir
31-
from jax._src.lib.mlir.dialects import func as func_dialect
32-
from jax._src.lib.mlir.dialects import hlo
3330
import jax.extend as jex
31+
from jax.extend.mlir import ir
32+
from jax.extend.mlir.dialects import func as func_dialect
33+
from jax.extend.mlir.dialects import stablehlo as hlo
3434
from jax.interpreters import mlir
3535
from jax.interpreters import xla
3636
from jax_tpu_embedding.sparsecore.lib.core import constants

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_adam.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
from typing import Tuple
2828

2929
import jax
30-
from jax._src.lib.mlir import ir
31-
from jax._src.lib.mlir.dialects import func as func_dialect
32-
from jax._src.lib.mlir.dialects import hlo
3330
import jax.extend as jex
31+
from jax.extend.mlir import ir
32+
from jax.extend.mlir.dialects import func as func_dialect
33+
from jax.extend.mlir.dialects import stablehlo as hlo
3434
from jax.interpreters import mlir
3535
from jax.interpreters import xla
3636
from jax_tpu_embedding.sparsecore.lib.core import constants

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_ftrl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from typing import Tuple
2929

3030
import jax
31-
from jax._src.lib.mlir import ir
32-
from jax._src.lib.mlir.dialects import func as func_dialect
33-
from jax._src.lib.mlir.dialects import hlo
3431
import jax.extend as jex
32+
from jax.extend.mlir import ir
33+
from jax.extend.mlir.dialects import func as func_dialect
34+
from jax.extend.mlir.dialects import stablehlo as hlo
3535
from jax.interpreters import mlir
3636
from jax.interpreters import xla
3737
from jax_tpu_embedding.sparsecore.lib.core import constants

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_laprop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from typing import Tuple
2727

2828
import jax
29-
from jax._src.lib.mlir import ir
30-
from jax._src.lib.mlir.dialects import func as func_dialect
31-
from jax._src.lib.mlir.dialects import hlo
3229
import jax.extend as jex
30+
from jax.extend.mlir import ir
31+
from jax.extend.mlir.dialects import func as func_dialect
32+
from jax.extend.mlir.dialects import stablehlo as hlo
3333
from jax.interpreters import mlir
3434
from jax.interpreters import xla
3535
from jax_tpu_embedding.sparsecore.lib.core import constants

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_sgd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import json
1818

1919
import jax
20-
from jax._src.lib.mlir import ir
21-
from jax._src.lib.mlir.dialects import func as func_dialect
22-
from jax._src.lib.mlir.dialects import hlo
2320
import jax.extend as jex
21+
from jax.extend.mlir import ir
22+
from jax.extend.mlir.dialects import func as func_dialect
23+
from jax.extend.mlir.dialects import stablehlo as hlo
2424
from jax.interpreters import mlir
2525
from jax.interpreters import xla
2626
from jax_tpu_embedding.sparsecore.lib.core import constants

0 commit comments

Comments
 (0)