From f2121a72fc97c555eda6f519dba28dfe883e62cb Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 8 May 2025 14:39:59 +0300 Subject: [PATCH] [platform_dependent] Ensure that platform_dependent only lowers for intended platforms Fixes: #28594 Currently `lax.platform_dependent` allows specifying code that behaves differently when lowered on different platforms. However, this function operates in a confusing way, in that it will create a branch on the platform, but will lower all branches for the **current** lowering platforms. For example, in the following code: ``` lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu) ``` If we lower for CPU, we lower both `for_cpu` and `for_tpu` for CPU (!), but only the branch corresponding to `for_cpu` will actually run. This is a problem if, e.g., `for_tpu` does not have a lowering for CPU. We will get an error during lowering. Instead there should be no error during lowering, because that branch is not actually needed. We add a new test `test_platform_dependent_with_primitive_with_lowering_error` to demonstrate this. The solution implemented here is Solution A from #28594: we add a `branches_platforms` param to the `cond` primitive, which is propagated by all transformations. This param is used only for the conditionals arising from `lax.platform_dependent`. During lowering we drop the branches corresponding to the platforms that are not interesting. 
--- jax/_src/checkify.py | 5 +- jax/_src/interpreters/mlir.py | 8 +- jax/_src/lax/control_flow/__init__.py | 1 + jax/_src/lax/control_flow/conditionals.py | 176 ++++++++++++++-------- jax/_src/pallas/mosaic/lowering.py | 12 +- jax/_src/pallas/mosaic_gpu/lowering.py | 5 +- jax/experimental/jax2tf/jax2tf.py | 5 +- tests/lax_control_flow_test.py | 79 ++++++++-- 8 files changed, 201 insertions(+), 90 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5a6456762db7..144cbaf5cd21 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -759,7 +759,8 @@ def jaxpr_to_checkify_jaxpr( out_tree, error_effects = metadata() return checked_jaxpr, out_tree, error_effects -def cond_error_check(error: Error, enabled_errors, index, *ops, branches): +def cond_error_check(error: Error, enabled_errors, index, *ops, + branches, **params): # Get the error-effects out of all branches so the cond can be called with # a merged error with all these effects. err_vals, err_tree = jtu.tree_flatten(error) @@ -780,7 +781,7 @@ def get_error_effects_from_jaxpr(jxpr): err_and_outs = lax.cond_p.bind( index, *err_vals, *ops, - branches=tuple(new_branches)) + branches=tuple(new_branches), **params) # we need to merge metadata across out_trees (a tuple) err0, out = tree_unflatten(out_trees[0], err_and_outs) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e9deb8d3fff9..f6ef5787ccbf 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2080,6 +2080,11 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return ('tpu',) return () +def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: + """The lowering platforms for the current eqn""" + return tuple((_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms)) + def lower_per_platform(ctx: LoweringRuleContext, description: str, @@ -2122,8 +2127,7 @@ def lower_per_platform(ctx: LoweringRuleContext, 
rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or - ctx.platforms or ctx.module_context.platforms) + platforms: Sequence[str] = _platforms_for_eqn(ctx) # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index f89e4d53a476..44ee94e14ca2 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -34,6 +34,7 @@ while_p as while_p, ) from jax._src.lax.control_flow.conditionals import ( + BranchesPlatforms as BranchesPlatforms, cond as cond, cond_p as cond_p, switch as switch, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 99fa72421ea1..d875989921d0 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -46,6 +46,7 @@ from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.traceback_util import api_boundary +from jax._src.typing import ArrayLike from jax._src.util import safe_map, split_list, partition_list, unzip2 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -127,9 +128,17 @@ def switch(index, branches, *operands): lo = np.array(0, np.int32) hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) + return _switch_internal(index, branches, operands, + branches_platforms=None) + +def _switch_internal( + index: ArrayLike, + branches: Sequence[Callable], + operands: Sequence[ArrayLike], *, + branches_platforms: BranchesPlatforms | None): if (config.disable_jit.value and core.is_concrete(index)): - return branches[int(index)](*operands) + return branches[int(index)](*operands) # type: ignore dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch 
in branches] @@ -159,7 +168,10 @@ def switch(index, branches, *operands): raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + params = dict(branches=tuple(jaxprs)) + if branches_platforms is not None: + params["branches_platforms"] = branches_platforms + out = cond_p.bind(index, *consts, *ops, **params) out_ = iter(out) all_inputs = [*consts, *ops] @@ -464,7 +476,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(axis_data, args, dims, branches): +def _cond_batching_rule(axis_data, args, dims, *, branches, **params): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -480,6 +492,9 @@ def _cond_batching_rule(axis_data, args, dims, branches): if index_dim is not batching.not_mapped: + assert "branches_platforms" not in params, ( + "The index of a cond with branches_platforms should be a " + "platform_index and should never be mapped") # Convert to a lax.select. 
While we could get away with not broadcasting # some operands yet, because all outputs must be broadcast together anyway # for the select we broadcast the input operands for simplicity and leave @@ -518,10 +533,11 @@ def _cond_batching_rule(axis_data, args, dims, branches): for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind(index, *ops, branches=branches_batched) + out = cond_p.bind(index, *ops, branches=branches_batched, + **params) return out, out_dims -def _cond_jvp(primals, tangents, branches): +def _cond_jvp(primals, tangents, *, branches, **params): nonzeros = [type(t) is not ad_util.Zero for t in tangents] index_nz, *ops_nz = nonzeros @@ -538,14 +554,15 @@ def _cond_jvp(primals, tangents, branches): _, *ops_dot = tangents ops_dot = _prune_zeros(ops_dot) - out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) + out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp, + **params) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents -def _cond_partial_eval(trace, *tracers, branches): +def _cond_partial_eval(trace, *tracers, branches, **params): in_unknowns = [t.pval[0] is not None for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in @@ -556,7 +573,7 @@ def _cond_partial_eval(trace, *tracers, branches): if index_uk: # When the branch index is unknown, we stage out the whole cond. 
# TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches) + params = dict(branches=branches, **params) return trace.default_process_primitive(cond_p, tracers, params) branches_out_uks = [] @@ -586,7 +603,8 @@ def _cond_partial_eval(trace, *tracers, branches): for j in branches_known[1:]) in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known) + out_consts_res = cond_p.bind(*in_consts, branches=branches_known, + **params) out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) index_tracer = trace.instantiate_const(tracers[0]) @@ -595,7 +613,7 @@ def _cond_partial_eval(trace, *tracers, branches): res_tracers = map(trace.new_instantiated_const, res) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in branches_unknown[0].out_avals] - params = dict(branches=branches_unknown) + params = dict(branches=branches_unknown, **params) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( @@ -608,6 +626,7 @@ def _cond_partial_eval(trace, *tracers, branches): def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): index_uk, *ops_uk = unks_in branches = eqn.params['branches'] + eqn_rest_params = dict(k_v for k_v in eqn.params.items() if k_v[0] != 'branches') # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -664,7 +683,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the known eqn. 
ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar out_binders_known, _ = partition_list(unks_out, eqn.outvars) - params_known = dict(branches=branches_known) + params_known = dict(branches=branches_known, **eqn_rest_params) effects_known = _join_cond_effects(branches_known) eqn_known = pe.new_jaxpr_eqn( ins_known, [*out_binders_known, *res_binders], cond_p, params_known, @@ -672,7 +691,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the staged eqn. _, out_binders_staged = partition_list(inst_out, eqn.outvars) - params_staged = dict(branches=branches_staged) + params_staged = dict(branches=branches_staged, **eqn_rest_params) effects_staged = _join_cond_effects(branches_staged) eqn_staged = pe.new_jaxpr_eqn( [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged, @@ -818,7 +837,7 @@ def transposed(*args): debug_info=jaxpr.jaxpr.debug_info), res_avals + jaxpr.out_avals) -def _cond_transpose(cts, *args, branches): +def _cond_transpose(cts, *args, branches, **params): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] @@ -838,7 +857,8 @@ def _cond_transpose(cts, *args, branches): res = ops[:num_res] cts = map(ad.instantiate_zeros, cts) - out = cond_p.bind(index, *res, *cts, branches=branches_trans) + out = cond_p.bind(index, *res, *cts, branches=branches_trans, + **params) assert all(map(core.typecheck, lin_in_avals, out)) out_iter = iter(out) @@ -846,7 +866,8 @@ def _cond_transpose(cts, *args, branches): assert next(out_iter, None) is None return [None] + out -def _cond_typecheck(bind_time, *in_atoms, branches): +def _cond_typecheck(bind_time, *in_atoms, branches, **params): + del params if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -900,6 +921,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects + 
+BranchesPlatforms = tuple[tuple[str, ...] | None, ...] +# cond_p takes an optional branches_platforms param of type `BranchesPlatforms` +# when it is a `platform_dependent` conditional. +# In that case, `branches_platforms` is a tuple as long +# as `branches` and for each branch it specifies the lowering platforms it +# corresponds to. The last element, corresponding to the last branch, +# can be `None` to represent a default match-all-lowering-platforms. +# The index argument of a `platform_dependent` cond is always a +# `platform_index` primitive. cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.skip_canonicalization = True @@ -915,7 +946,39 @@ def _cond_typecheck(bind_time, *in_atoms, branches): pe.dce_rules[cond_p] = _cond_dce_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule -def _cond_lowering(ctx, index, *args, branches): +def _cond_lowering(ctx, index, *args, branches, + **params): + if (branches_platforms := params.get("branches_platforms", None)) is not None: + branches_kept: list[core.ClosedJaxpr] = [] + index_to_kept_index: dict[int, int] = {} + for p in mlir._platforms_for_eqn(ctx): + # Each `p` must appear in exactly one branches_platforms, or in the + # last default branch. Otherwise, platform_index lowering would have + # failed already. 
+ for b_idx, b_platforms in enumerate(branches_platforms): + if b_platforms is None or p in b_platforms: + if b_idx not in index_to_kept_index: + index_to_kept_index[b_idx] = len(branches_kept) + branches_kept.append(branches[b_idx]) + break + else: + assert False, p + + # Compute the new index into branches_kept + i32_type = ir.RankedTensorType.get([], mlir.dtype_to_ir_type(dtypes.dtype(np.int32))) + kept_index_case_op = hlo.CaseOp([i32_type], + index=index, + num_branches=len(branches)) + for i in range(len(branches)): + branch = kept_index_case_op.regions[i].blocks.append() + with ir.InsertionPoint(branch): + kept_i = np.int32(index_to_kept_index.get(i, 0)) + hlo.return_([mlir.ir_constant(kept_i)]) + + index = kept_index_case_op + branches = branches_kept + assert branches, "platform_index lowering should have failed first" + joined_effects = core.join_effects(*(branch.effects for branch in branches)) ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) @@ -952,7 +1015,8 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) @register_partial_discharge_rule(cond_p) -def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches): +def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, + branches, **params): assert not should_discharge[0], "Can't discharge the index." 
discharged_branches = tuple( discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0] @@ -981,7 +1045,8 @@ def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *ar if fwd is None]), ()) for branch in discharged_branches ) - out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches) + out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches, + **params) out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)]) # Insert forwarded values into reference outputs ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd) @@ -1046,50 +1111,41 @@ def other_platforms_code(*args): ... The value ``per_platform[execution_platform](*args)``. """ # Join identical branches - platform_branches: list[tuple[list[str], Callable]] = [] + branches_platforms_list: list[tuple[list[str], Callable]] = [] for pname, pbranch in per_platform.items(): + if not callable(pbranch): + raise TypeError(f"lax.platform_dependent: the '{pname}' branch must " + "be a callable.") if pname == "gpu": raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.") - for ps, b in platform_branches: + for ps, b in branches_platforms_list: if b == pbranch: ps.append(pname) break else: - platform_branches.append(([pname], pbranch)) - - platforms_lists, branches = util.unzip2(platform_branches) - platform_index = platform_index_p.bind( - platforms=tuple(tuple(ps) for ps in platforms_lists), - has_default=(default is not None)) + branches_platforms_list.append(([pname], pbranch)) + platforms_lists, branches = util.unzip2(branches_platforms_list) + branches_platforms: BranchesPlatforms = tuple(tuple(ps) for ps in platforms_lists) if default is not None: + if not callable(default): + raise TypeError("lax.platform_dependent: the 'default' branch must " + "be a callable.") branches = branches + (default,) - # Use a switch, to get the proper transformation rules for free. 
Since - # platform index has no dependence on the input data, it won't be vectorized - # under vmap. - # If the switch and the platform_index_p above are in the same compilation - # unit then constant-folding will remove the unnecessary branches. However, - # if we run in eager mode the switch below cannot be constant-folded and - # the compilation may fail if some of the branches contain custom calls not - # recognized on the compilation platform. Detect eager mode and keep only the - # needed branch. - try: - # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into - # core.ensure_compile_time_eval to get better single-branch selection. - platform_index_concrete = core.concrete_or_error(operator.index, platform_index) - except core.ConcretizationTypeError: - return switch(platform_index, branches, *args) - else: - assert 0 <= platform_index_concrete < len(branches) - return branches[platform_index_concrete](*args) + branches_platforms = branches_platforms + (None,) # type: ignore + platform_index = platform_index_p.bind(platforms=branches_platforms) + + if core.is_concrete(platform_index): + return branches[int(platform_index)](*args) + return _switch_internal(platform_index, branches, args, + branches_platforms=branches_platforms) + # A primitive to compute the index of a platform into a list of platforms. # Args: -# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform -# names. If the current lowering platform is in one of the inner sequences -# returns the index of that inner sequence in the outer sequence. -# has_default: if True, and if the lowering platform is not found in -# `platforms` then return `len(platforms)`. Otherwise, raise an error. +# platforms: BranchesPlatforms. If the current lowering +# platform is in one of the inner tuples returns the index of that inner +# tuple in the outer tuple. 
platform_index_p = core.Primitive("platform_index") platform_index_p.multiple_results = False platform_index_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -1101,25 +1157,25 @@ def _platform_index_aval(*_, **__): def _platform_index_lowering(ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool): - def lower_constant( - ctx: mlir.LoweringRuleContext, *, i: int - ) -> Sequence[ir.Value]: + platforms: BranchesPlatforms): + def lower_constant(ctx: mlir.LoweringRuleContext, *, + i: int) -> Sequence[ir.Value]: v = mlir.ir_constant(np.int32(i)) - assert isinstance(v, ir.Value), v return [v] + platform_rules: dict[str, mlir.LoweringRule] = {} + default_rule = None for i, ps in enumerate(platforms): rule = partial(lower_constant, i=i) - for p in ps: - platform_rules[p] = rule + if ps is None: + default_rule = rule + else: + for p in ps: + platform_rules[p] = rule - default_rule = ( - partial(lower_constant, i=len(platforms)) if has_default else None) return mlir.lower_per_platform( ctx, - f"platform_index(platforms={platforms}, has_default={has_default})", + f"platform_index(platforms={platforms})", platform_rules, default_rule, effects.no_effects) mlir.register_lowering(platform_index_p, _platform_index_lowering) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1ea5a048a17e..bba49c75f9df 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -47,7 +47,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal -from jax._src.lax.control_flow import for_loop +from jax._src.lax.control_flow import for_loop, BranchesPlatforms from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -3100,7 +3100,7 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p) -def _cond_lowering_rule(ctx: 
LoweringRuleContext, *args, branches): +def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params): index, *args = args constant_index = _fold_and_get_constant_value(index) @@ -3870,17 +3870,13 @@ def _pad(val): def _platform_index_lowering( ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool, + platforms: BranchesPlatforms, ): for i, ps in enumerate(platforms): # note - slightly odd structure here, as platforms is a seq[seq[str]] - if "mosaic" in ps: + if ps is None or "mosaic" in ps: return ir_constant(i) - if has_default: - return ir_constant(len(platforms)) - raise NotImplementedError( "No mosaic or default platform indexing rule found." ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b501693bf627..9ead4f16c1a6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2598,7 +2598,10 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) @register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) -def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches, + **params): + if params: + raise NotImplementedError("platform_dependent cond") index_aval, *_arg_avals = ctx.avals_in def _yielded_values(outs, avals): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 4c2f35a95c57..786e021e2ff0 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3062,8 +3062,11 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: def _cond( - index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr] + index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], + **params ) -> Sequence[TfVal]: + if params: + raise NotImplementedError("jax2tf conversion for 
platform_dependent") # tf.cond needs lambdas with no arguments. branches_tf = [ partial(_interpret_jaxpr, jaxpr, *operands, diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 422ef769e392..d32d761ee1fa 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -37,8 +37,10 @@ from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp +from jax._src import dispatch from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import batching from jax._src.interpreters import mlir jax.config.parse_flags_with_absl() @@ -137,6 +139,36 @@ def scan_reference(f, init, xs): lambda ctx, x: mlir.hlo.CustomCallOp( [x.type], [x], call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) +batching.primitive_batchers[prim_non_existent_custom_call] = ( + lambda batched_args, batch_dims: (prim_non_existent_custom_call.bind(batched_args[0]), + batch_dims[0])) + +# A JAX primitive that triggers error when lowering on unintended platforms +prim_with_lowering_error = core.Primitive("__testing_prim_with_lowering_error") +prim_with_lowering_error.def_abstract_eval(lambda x_aval, **_: x_aval) +def prim_with_lowering_error_lowering(platform: str, + ctx: mlir.LoweringRuleContext, x, *, + only_on: str): + if platform != only_on: + raise ValueError(f"prim_with_lowering_error with only_on={only_on} lowered for {platform}") + return mlir.hlo.SineOp(x).results +def prim_with_lowering_error_batch_rule(batched_args, batch_dims, **params): + xs, = batched_args + xs_bdim, = batch_dims + return prim_with_lowering_error.bind(xs, **params), xs_bdim + +batching.primitive_batchers[prim_with_lowering_error] = prim_with_lowering_error_batch_rule + +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "cpu"), + 
platform="cpu") +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "tpu"), + platform="tpu") +prim_with_lowering_error.def_impl(partial(dispatch.apply_primitive, + prim_with_lowering_error)) class LaxControlFlowTest(jtu.JaxTestCase): @@ -1378,7 +1410,7 @@ def f(x): @parameterized.named_parameters( {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) - def testCondGrad2(self, cond): + def testCondGrad2(self, cond=cond_with_new_checkpoint): def f_ref(x): z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x) return z.sum() @@ -2905,18 +2937,13 @@ def f(x): x = np.arange(3, dtype=np.float32) lowered = jax.jit(f).lower(x) stablehlo = lowered.as_text() - self.assertIn("stablehlo.case", stablehlo) - self.assertIn("stablehlo.sine", stablehlo) - self.assertIn("stablehlo.cosine", stablehlo) - - # The HLO has been canonicalized and contains only the branch we need - hlo = lowered.as_text("hlo") + # The StableHLO contains only the branch we need if jtu.device_under_test() == "cpu": - self.assertIn(" sine", hlo) - self.assertNotIn(" cosine", hlo) + self.assertIn("stablehlo.sine", stablehlo) + self.assertNotIn("stablehlo.cosine", stablehlo) else: - self.assertNotIn(" sine", hlo) - self.assertIn(" cosine", hlo) + self.assertNotIn("stablehlo.sine", stablehlo) + self.assertIn("stablehlo.cosine", stablehlo) def test_platform_dependent_with_non_existent_custom_call(self): if not jtu.test_device_matches(["cpu"]): @@ -2939,8 +2966,7 @@ def f(x): x = np.arange(3, dtype=np.float32) hlo = str(jax.jit(f).lower(x).compiler_ir()) - occurrences = re.findall(prim_non_existent_custom_call.name, hlo) - self.assertLen(occurrences, 3) + self.assertNotIn(prim_non_existent_custom_call.name, hlo) res_eager = f(x) self.assertAllClose(res_eager, 3. * np.sin(x)) @@ -2956,6 +2982,26 @@ def f(x): res_grad = jax.grad(f)(1.) self.assertAllClose(res_grad, 3. 
* np.cos(1.)) + def test_platform_dependent_with_primitive_with_lowering_error(self): + if not jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Only for CPU and TPU") + + def f(x): + return lax.platform_dependent( + x, + # Check that we only lower on the intended platform + cpu=lambda x: prim_with_lowering_error.bind(x, only_on="cpu"), + tpu=lambda x: prim_with_lowering_error.bind(x, only_on="tpu")) + + self.assertAllClose(np.sin(1.), f(1.)) # Eager + self.assertAllClose(np.sin(1.), jax.jit(f)(1.)) + self.assertAllClose(np.sin(1.), lax.cond(True, f, lambda x: x, 1.)) + self.assertAllClose(1., lax.cond(False, f, lambda x: x, 1.)) + self.assertAllClose((0., np.sin(np.arange(8.))), + lax.scan(lambda carry, x: (carry, f(x)), + 0., np.arange(8.))) + self.assertAllClose(np.sin(np.arange(8.)), jax.vmap(f)(np.arange(8.))) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2965,13 +3011,14 @@ def f(x): tpu=jnp.sin, default=lambda x: x) res = f(x) + on_cpu_tpu = jtu.device_under_test() in ["cpu", "tpu"] self.assertAllClose( res, - np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x) - # We only lower the common branches once + np.sin(x) if on_cpu_tpu else x) + stablehlo = jax.jit(f).lower(x).as_text() sines = re.findall(r"stablehlo.sine", stablehlo) - self.assertEqual(1, len(sines)) + self.assertEqual(1 if on_cpu_tpu else 0, len(sines)) def test_platform_dependent_no_default(self): ctx = contextlib.ExitStack()