Skip to content

[platform_dependent] Ensure that platform_dependent only lowers for intended platforms #28607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ def jaxpr_to_checkify_jaxpr(
out_tree, error_effects = metadata()
return checked_jaxpr, out_tree, error_effects

def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
def cond_error_check(error: Error, enabled_errors, index, *ops,
branches, **params):
# Get the error-effects out of all branches so the cond can be called with
# a merged error with all these effects.
err_vals, err_tree = jtu.tree_flatten(error)
Expand All @@ -780,7 +781,7 @@ def get_error_effects_from_jaxpr(jxpr):

err_and_outs = lax.cond_p.bind(
index, *err_vals, *ops,
branches=tuple(new_branches))
branches=tuple(new_branches), **params)

# we need to merge metadata across out_trees (a tuple)
err0, out = tree_unflatten(out_trees[0], err_and_outs)
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,6 +2080,11 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
return ('tpu',)
return ()

def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]:
  """The lowering platforms for the current eqn."""
  # A per-eqn platform override takes precedence; otherwise fall back to
  # the rule context's platforms, and finally to the module-wide platforms.
  selected = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx)
              or ctx.platforms
              or ctx.module_context.platforms)
  return tuple(selected)


def lower_per_platform(ctx: LoweringRuleContext,
description: str,
Expand Down Expand Up @@ -2122,8 +2127,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or
ctx.platforms or ctx.module_context.platforms)
platforms: Sequence[str] = _platforms_for_eqn(ctx)
# Special case the common case (single-platform lowering)
if len(platforms) == 1:
rule = platform_rules.get(platforms[0], default_rule)
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax/control_flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
while_p as while_p,
)
from jax._src.lax.control_flow.conditionals import (
BranchesPlatforms as BranchesPlatforms,
cond as cond,
cond_p as cond_p,
switch as switch,
Expand Down
176 changes: 116 additions & 60 deletions jax/_src/lax/control_flow/conditionals.py

Large diffs are not rendered by default.

12 changes: 4 additions & 8 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lax.control_flow import for_loop, BranchesPlatforms
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
Expand Down Expand Up @@ -3098,7 +3098,7 @@ def _while_lowering_rule(


@register_lowering_rule(lax.cond_p)
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params):
index, *args = args
constant_index = _fold_and_get_constant_value(index)

Expand Down Expand Up @@ -3866,17 +3866,13 @@ def _pad(val):
def _platform_index_lowering(
    ctx: mlir.LoweringRuleContext,
    *,
    platforms: BranchesPlatforms,
):
  """Mosaic lowering rule for the platform-index primitive.

  Returns the index (as an IR constant) of the first branch whose platform
  list contains "mosaic", or whose entry is None — presumably the default
  branch that matches any platform, consistent with the removed
  `has_default` handling; TODO confirm against BranchesPlatforms.

  Raises:
    NotImplementedError: if no mosaic or default branch is found.
  """
  for i, ps in enumerate(platforms):
    # Check `ps is None` FIRST: `or` evaluates left-to-right, and
    # `"mosaic" in None` would raise TypeError before the None test runs.
    if ps is None or "mosaic" in ps:
      return ir_constant(i)

  raise NotImplementedError(
      "No mosaic or default platform indexing rule found."
  )
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2558,7 +2558,10 @@ def _while_lowering_rule(
@register_lowering_rule(lax.cond_p,
mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp)
@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup)
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches,
**params):
if params:
raise NotImplementedError("platform_dependent cond")
index_aval, *_arg_avals = ctx.avals_in

def _yielded_values(outs, avals):
Expand Down
5 changes: 4 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3062,8 +3062,11 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:


def _cond(
index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr]
index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
**params
) -> Sequence[TfVal]:
if params:
raise NotImplementedError("jax2tf conversion for platform_dependent")
# tf.cond needs lambdas with no arguments.
branches_tf = [
partial(_interpret_jaxpr, jaxpr, *operands,
Expand Down
81 changes: 64 additions & 17 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax._src import dispatch
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax.control_flow import for_loop
from jax._src.interpreters import batching
from jax._src.interpreters import mlir

jax.config.parse_flags_with_absl()
Expand Down Expand Up @@ -137,6 +139,36 @@ def scan_reference(f, init, xs):
lambda ctx, x: mlir.hlo.CustomCallOp(
[x.type], [x],
call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results)
batching.primitive_batchers[prim_non_existent_custom_call] = (
lambda batched_args, batch_dims: (prim_non_existent_custom_call.bind(batched_args[0]),
batch_dims[0]))

# A JAX primitive that triggers error when lowering on unintended platforms.
# Used by tests below to verify that platform_dependent only invokes the
# lowering rule for the platform actually being lowered.
prim_with_lowering_error = core.Primitive("__testing_prim_with_lowering_error")
# Shape/dtype-preserving: the output aval is the input aval.
prim_with_lowering_error.def_abstract_eval(lambda x_aval, **_: x_aval)
def prim_with_lowering_error_lowering(platform: str,
                                      ctx: mlir.LoweringRuleContext, x, *,
                                      only_on: str):
  # Raise at lowering time (not run time) if this rule fires for a platform
  # other than the one the primitive was bound with.
  if platform != only_on:
    raise ValueError(f"prim_with_lowering_error with only_on={only_on} lowered for {platform}")
  return mlir.hlo.SineOp(x).results
def prim_with_lowering_error_batch_rule(batched_args, batch_dims, **params):
  # Trivial batch rule: rebind on the batched operand, keeping its batch dim.
  xs, = batched_args
  xs_bdim, = batch_dims
  return prim_with_lowering_error.bind(xs, **params), xs_bdim

batching.primitive_batchers[prim_with_lowering_error] = prim_with_lowering_error_batch_rule

# Register one lowering per platform; each rule only accepts its own platform.
mlir.register_lowering(
    prim_with_lowering_error,
    partial(prim_with_lowering_error_lowering, "cpu"),
    platform="cpu")
mlir.register_lowering(
    prim_with_lowering_error,
    partial(prim_with_lowering_error_lowering, "tpu"),
    platform="tpu")
# Eager execution goes through the standard dispatch path.
prim_with_lowering_error.def_impl(partial(dispatch.apply_primitive,
                                          prim_with_lowering_error))


class LaxControlFlowTest(jtu.JaxTestCase):
Expand Down Expand Up @@ -1378,7 +1410,7 @@ def f(x):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondGrad2(self, cond):
def testCondGrad2(self, cond=cond_with_new_checkpoint):
def f_ref(x):
z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x)
return z.sum()
Expand Down Expand Up @@ -2905,18 +2937,13 @@ def f(x):
x = np.arange(3, dtype=np.float32)
lowered = jax.jit(f).lower(x)
stablehlo = lowered.as_text()
self.assertIn("stablehlo.case", stablehlo)
self.assertIn("stablehlo.sine", stablehlo)
self.assertIn("stablehlo.cosine", stablehlo)

# The HLO has been canonicalized and contains only the branch we need
hlo = lowered.as_text("hlo")
# The StableHLO contains only the branch we need
if jtu.device_under_test() == "cpu":
self.assertIn(" sine", hlo)
self.assertNotIn(" cosine", hlo)
self.assertIn("stablehlo.sine", stablehlo)
self.assertNotIn("stablehlo.cosine", stablehlo)
else:
self.assertNotIn(" sine", hlo)
self.assertIn(" cosine", hlo)
self.assertNotIn("stablehlo.sine", stablehlo)
self.assertIn("stablehlo.cosine", stablehlo)

def test_platform_dependent_with_non_existent_custom_call(self):
if not jtu.test_device_matches(["cpu"]):
Expand All @@ -2939,8 +2966,7 @@ def f(x):

x = np.arange(3, dtype=np.float32)
hlo = str(jax.jit(f).lower(x).compiler_ir())
occurrences = re.findall(prim_non_existent_custom_call.name, hlo)
self.assertLen(occurrences, 3)
self.assertNotIn(prim_non_existent_custom_call.name, hlo)

res_eager = f(x)
self.assertAllClose(res_eager, 3. * np.sin(x))
Expand All @@ -2956,6 +2982,26 @@ def f(x):
res_grad = jax.grad(f)(1.)
self.assertAllClose(res_grad, 3. * np.cos(1.))

def test_platform_dependent_with_primitive_with_lowering_error(self):
  """platform_dependent must lower each branch only for its own platform.

  Uses a primitive whose lowering rule raises if invoked for any platform
  other than the one it was bound with, then exercises eager, jit, cond,
  scan, and vmap paths.
  """
  if not jtu.test_device_matches(["cpu", "tpu"]):
    self.skipTest("Only for CPU and TPU")

  def f(x):
    return lax.platform_dependent(
        x,
        # Check that we only lower on the intended platform
        cpu=lambda x: prim_with_lowering_error.bind(x, only_on="cpu"),
        tpu=lambda x: prim_with_lowering_error.bind(x, only_on="tpu"))

  # Both branches compute sin(x); lowering raises if the wrong branch is
  # lowered, so these checks also exercise the lowering paths.
  self.assertAllClose(np.sin(1.), f(1.))  # Eager
  self.assertAllClose(np.sin(1.), jax.jit(f)(1.))
  self.assertAllClose(np.sin(1.), lax.cond(True, f, lambda x: x, 1.))
  self.assertAllClose(1., lax.cond(False, f, lambda x: x, 1.))
  self.assertAllClose((0., np.sin(np.arange(8.))),
                      lax.scan(lambda carry, x: (carry, f(x)),
                               0., np.arange(8.)))
  self.assertAllClose(np.sin(np.arange(8.)), jax.vmap(f)(np.arange(8.)))

def test_platform_dependent_multiple_identical_branches(self):
x = np.arange(3, dtype=np.float32)
def f(x):
Expand All @@ -2965,13 +3011,14 @@ def f(x):
tpu=jnp.sin,
default=lambda x: x)
res = f(x)
on_cpu_tpu = jtu.device_under_test() in ["cpu", "tpu"]
self.assertAllClose(
res,
np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x)
# We only lower the common branches once
np.sin(x) if on_cpu_tpu else x)

stablehlo = jax.jit(f).lower(x).as_text()
sines = re.findall(r"stablehlo.sine", stablehlo)
self.assertEqual(1, len(sines))
self.assertEqual(1 if on_cpu_tpu else 0, len(sines))

def test_platform_dependent_no_default(self):
ctx = contextlib.ExitStack()
Expand Down Expand Up @@ -3059,7 +3106,7 @@ def test_scan_unroll_concrete_error(self):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(True, 1.)

def test_cond_vmap_forwarding_doesnt_promote(self):
def test_condvmap_forwarding_doesnt_promote(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: was this meant to be included in this PR?

def f(x, y):
x, y = jax.lax.cond(
x < 3,
Expand Down
Loading