Skip to content

Commit edf0a20

Browse files
committed
[platform_dependent] Ensure that platform_dependent only lowers for intended platforms
Fixes: #28594 Currently `lax.platform_dependent` allows specifying code that behaves differently when lowered on different platforms. However, this function operates in a confusing way, in that it will create a branch on the platform, but will lower all branches for the **current** lowering platforms. For example, in the following code: ``` lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu) ``` If we lower for CPU, we lower both `for_cpu` and `for_tpu` for CPU (!), but only the branch corresponding to `for_cpu` will actually run. This is a problem if, e.g., `for_tpu` does not have a lowering for CPU. We will get an error during lowering. Instead there should be no error during lowering, because that branch is not actually needed. We add a new test `test_platform_dependent_with_primitive_with_lowering_error` to demonstrate this. The solution implememented here is the Solution A from #28594: we add a `branches_platform` param to the `cond` primitive, which is propagated by all transformations. This param is used only for the conditionals arising from `lax.platform_dependendet`. During lowering we drop the branches corresponding to the platforms that are not interesting.
1 parent 8137c37 commit edf0a20

File tree

7 files changed

+196
-81
lines changed

7 files changed

+196
-81
lines changed

jax/_src/checkify.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,8 @@ def jaxpr_to_checkify_jaxpr(
759759
out_tree, error_effects = metadata()
760760
return checked_jaxpr, out_tree, error_effects
761761

762-
def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
762+
def cond_error_check(error: Error, enabled_errors, index, *ops,
763+
branches, **params):
763764
# Get the error-effects out of all branches so the cond can be called with
764765
# a merged error with all these effects.
765766
err_vals, err_tree = jtu.tree_flatten(error)
@@ -780,7 +781,7 @@ def get_error_effects_from_jaxpr(jxpr):
780781

781782
err_and_outs = lax.cond_p.bind(
782783
index, *err_vals, *ops,
783-
branches=tuple(new_branches))
784+
branches=tuple(new_branches), **params)
784785

785786
# we need to merge metadata across out_trees (a tuple)
786787
err0, out = tree_unflatten(out_trees[0], err_and_outs)

jax/_src/interpreters/mlir.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,11 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
20802080
return ('tpu',)
20812081
return ()
20822082

2083+
def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]:
2084+
"""The lowering platforms for the current eqn"""
2085+
return tuple((_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or
2086+
ctx.platforms or ctx.module_context.platforms))
2087+
20832088

20842089
def lower_per_platform(ctx: LoweringRuleContext,
20852090
description: str,
@@ -2122,8 +2127,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
21222127
rule_args: the args of the lowering rules.
21232128
rule_kwargs: the kwargs of the lowering rules.
21242129
"""
2125-
platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or
2126-
ctx.platforms or ctx.module_context.platforms)
2130+
platforms: Sequence[str] = _platforms_for_eqn(ctx)
21272131
# Special case the common case (single-platform lowering)
21282132
if len(platforms) == 1:
21292133
rule = platform_rules.get(platforms[0], default_rule)

0 commit comments

Comments
 (0)