Skip to content

[platform_dependent] Ensure that platform_dependent only lowers for intended platforms #28607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented May 8, 2025

Fixes: #28594

Currently lax.platform_dependent allows specifying code that behaves
differently when lowered on different platforms. However, this function
operates in a confusing way, in that it will create a branch on the
platform, but will lower all branches for the current lowering platforms.

For example, in the following code:

   lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu)

If we lower for CPU, we lower both for_cpu and for_tpu
for CPU (!), but only the branch corresponding to for_cpu
will actually run.

This is a problem if, e.g., for_tpu does not have a lowering
for CPU. We will get an error during lowering. Instead there should
be no error during lowering, because that branch is not actually needed.

We add a new test test_platform_dependent_with_primitive_with_lowering_error
to demonstrate this.

The solution implememented here is the Solution A from #28594: we
add a branches_platform param to the cond primitive, which is
propagated by all transformations. This param is used only for the
conditionals arising from lax.platform_dependendet.
During lowering we drop the branches corresponding to the platforms
that are not interesting.

@gnecula gnecula self-assigned this May 8, 2025
@gnecula gnecula force-pushed the fix_platform_dependent branch from 3f0d55f to 2f0a846 Compare May 9, 2025 09:02
@gnecula gnecula added the pull ready Ready for copybara import and testing label May 9, 2025
@gnecula gnecula force-pushed the fix_platform_dependent branch 3 times, most recently from 300a845 to 63e2b65 Compare May 9, 2025 11:06
@gnecula gnecula requested review from dfm and froystig May 9, 2025 11:12
@gnecula gnecula force-pushed the fix_platform_dependent branch 3 times, most recently from 909f3fa to 4f0b117 Compare May 9, 2025 13:12
@dfm
Copy link
Contributor

dfm commented May 9, 2025

I think this looks great! On the design side, instead of threading branches_platforms through all the cond rules, I might be inclined to have a new primitive called something like platform_cond, that re-uses all the same rules. For example, we could re-write the JVP rule as:

def _cond_jvp(primitive, primals, tangents, *, branches, **params):
  ...
  out = primitive.bind(index, *ops, *ops_dot, branches=branches_jvp, **params)
  ...

ad.primitive_jvps[cond_p] = partial(_cond_jvp, cond_p)
ad.primitive_jvps[platform_cond_p] = partial(_cond_jvp, platform_cond_p)

It's a little bit more duplication of code, but my intuition is that it would be simpler to have a primitive with a name that more clearly indicates its purpose in the Jaxpr, rather than always carrying around this new parameter in the existing cond.

Besides the design question, the implementation looks great to me!

@gnecula
Copy link
Collaborator Author

gnecula commented May 10, 2025

Thank you for your review. I resonate with the concern that now all occurrences of cond in Jaxpr will have this branches_platforms=None. I am wondering if we can address that with a simpler solution. What do you think about having the branches_platforms parameter optional for the primitive, and it would be present only when it is really a platform-dependent conditional? For those instances of platform-dependent, the only visible difference from your proposal is that instead of the name of the primitive being different, it would have this branches_platforms, and the conditional is indexed by platform_index (as before). I think this should be clear enough.

In terms of the implementation this new variant is similar to your proposal in the use of **params in the rule definitions, but it does not introduce a new primitive and does not need to change the registration of all cond rules.

@gnecula gnecula force-pushed the fix_platform_dependent branch 2 times, most recently from edf0a20 to 7ef090d Compare May 10, 2025 14:11
…ntended platforms

Fixes: jax-ml#28594

Currently `lax.platform_dependent` allows specifying code that behaves
differently when lowered on different platforms. However, this function
operates in a confusing way, in that it will create a branch on the
platform, but will lower all branches for the **current** lowering platforms.

For example, in the following code:
```
   lax.platform_dependent(x,
                          cpu=for_cpu, tpu=for_tpu)
```

If we lower for CPU, we lower both `for_cpu` and `for_tpu`
for CPU (!), but only the branch corresponding to `for_cpu`
will actually run.

This is a problem if, e.g., `for_tpu` does not have a lowering
for CPU. We will get an error during lowering. Instead there should
be no error during lowering, because that branch is not actually needed.

We add a new test `test_platform_dependent_with_primitive_with_lowering_error`
to demonstrate this.

The solution implememented here is the Solution A from jax-ml#28594: we
add a `branches_platform` param to the `cond` primitive, which is
propagated by all transformations. This param is used only for the
conditionals arising from `lax.platform_dependendet`.
During lowering we drop the branches corresponding to the platforms
that are not interesting.
@gnecula gnecula force-pushed the fix_platform_dependent branch from 7ef090d to 97fc1a1 Compare May 10, 2025 14:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

jax.lax.platform_dependent doesn't stop Pallas from trying to lower for other backends?
2 participants