Skip to content

Commit 40c80d7

Browse files
hawkinspjax authors
authored and
jax authors
committed
Remove jax._src from JAX namespace.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX. PiperOrigin-RevId: 473243243
1 parent bc59bd1 commit 40c80d7

File tree

4 files changed

+8
-3
lines changed

4 files changed

+8
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
1515
{jax-issue}`#7733`) is stable and public. See [the
1616
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
1717
for {mod}`jax.stages`.
18+
* Breaking changes
19+
* `jax._src` is no longer imported into the from the public `jax` namespace.
20+
This may break users that were using JAX internals.
1821

1922
## jax 0.3.17 (Aug 31, 2022)
2023
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).

jax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,5 @@
146146
from jax import util as util
147147

148148
import jax.lib # TODO(phawkins): remove this export.
149+
150+
del jax._src

jax/experimental/jax2tf/tests/sharding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _check_sharding_annotations(self,
103103
device_assignment = np.arange(num_partitions * num_replicas)
104104
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
105105
use_spmd_partitioning = num_partitions > 1
106-
compile_options = jax._src.lib.xla_bridge.get_compile_options(
106+
compile_options = xla_bridge.get_compile_options(
107107
num_replicas=num_replicas,
108108
num_partitions=num_partitions,
109109
device_assignment=device_assignment,

jax/experimental/sparse/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from jax.interpreters import mlir
2525
from jax.interpreters import xla
2626

27-
from jax._src.lib import gpu_solver
27+
from jax._src.lib import gpu_solver, xla_extension_version
2828

2929
import numpy as np
3030

@@ -550,6 +550,6 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1):
550550
An array with the same dtype and size as b representing the solution to
551551
the sparse linear system.
552552
"""
553-
if jax._src.lib.xla_extension_version < 86:
553+
if xla_extension_version < 86:
554554
raise ValueError('spsolve requires jaxlib version 86 or above.')
555555
return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)

0 commit comments

Comments
 (0)