Add section for jax.experimental.custom_partitioning. #307

Open
wants to merge 1 commit into base: main
145 changes: 145 additions & 0 deletions docs/getting_started_jax.ipynb
@@ -513,6 +513,151 @@
"A lot easier to read!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8iiY4TU-4W2r"
},
"source": [
"### `jax.experimental.custom_partitioning`\n",
"\n",
"With GSPMD, we define two routines, `propagate_user_sharding` and `infer_sharding_from_operands`, that may traverses jaxpr to return the sharding for the operands and results in order to use custom_partitioning. With Shardy, we provide `sharding_rule` corresponding to an Einsum like notation string to specify sharding rule. Here is an example, where the routine that we use custom partition for implements a batch matrix multiplication."
]
},
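{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a sketch of what the two GSPMD callbacks could look like for this batched matmul. This cell is illustrative only and is not used by the rest of the example; the callback bodies are assumptions, not the original GSPMD implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative GSPMD-style callbacks (sketch only; not used below).\n",
"# With GSPMD, these would be passed to def_partition instead of sharding_rule.\n",
"def infer_sharding_from_operands(mesh, arg_shapes, result_shape):\n",
"  # One reasonable choice here: give the result the sharding of the first operand.\n",
"  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
"  return NamedSharding(mesh, arg_shardings[0].spec)\n",
"\n",
"def propagate_user_sharding(mesh, user_shape):\n",
"  # Forward the sharding requested by the consumer of the result.\n",
"  return user_shape.sharding"
]
},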
{
"cell_type": "markdown",
"metadata": {
"id": "eGv67hUt1L6K"
},
"source": [
"We use a device array of (2M, M) to compute a matmul with the form of (...4N, 2N) x (...2N, 4N). Notice that instead of hard-coding the device array and the matrix shapes, we introduce two parameters, M and N, for specifying the shapes of the matrixes and the shapes of the device array.\n",
"\n",
"We first perform the needed setup and define the `partition` routine as we would do with GSPMD."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YnqiE35x6n3Z"
},
"outputs": [],
"source": [
"from functools import partial\n",
"from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING\n",
"\n",
"jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
"\n",
"def partition(mesh, arg_shapes, result_shape):\n",
" arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
" result_sharding = result_shape.sharding\n",
" rank=len(arg_shapes[0].shape)\n",
"\n",
" def lower_fn(x, y):\n",
" axis_name = arg_shardings[1].spec[rank-2][0]\n",
" i = jax.lax.axis_index(axis_name)\n",
" z = jax.lax.psum(jax.lax.dynamic_slice_in_dim(jax.lax.dynamic_slice_in_dim(x, i * 0, N, axis=rank-2), i * N, N, axis=rank-1) @ y, (axis_name))\n",
" return z\n",
"\n",
" return mesh, lower_fn, (result_sharding), arg_shardings\n",
"\n",
"@partial(custom_partitioning)\n",
"def f(x, y):\n",
" return jnp.matmul(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HpR1TFdF0k7a"
},
"source": [
"Then, we invoke the `def_partition` API. Note that instead of providing two callbacks for parameters `infer_sharding_from_operands` and `propagate_user_sharding` as we would do with GSPMD, we provide a `sharding_rule` parameter, which is an einsum notation like string\n",
"similar to the subscripts in `jnp.einsum(\"...ij, ...jk-\u003e...ik\", x, y)`, if we would extend `jnp.einsum` to support the use of `...` for representing leading batching dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kQal0N-a54ga"
},
"outputs": [],
"source": [
"f.def_partition(\n",
" infer_sharding_from_operands=None,\n",
" propagate_user_sharding=None,\n",
" partition=partition,\n",
" sharding_rule=\"... i j, ... j k -\u003e ... i k\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bQKkgGu2LAUK"
},
"source": [
"Alternatively, we can also create an equivalent `SdyShardingRule` object for the `sharding_rule` parameter. See [Shardy document on sharding rule](https://github.com/openxla/shardy/blob/main/docs/propagation.md#operation-sharding-rule) for more details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pWjSKroDLCLK"
},
"outputs": [],
"source": [
"f.def_partition(\n",
" infer_sharding_from_operands=None,\n",
" propagate_user_sharding=None,\n",
" partition=partition,\n",
" sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i', 'j'), (BATCHING, 'j', 'k')), result_mappings=((BATCHING, 'i', 'k'),)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kVqOiV5o6f_j"
},
"source": [
"Lastly, we create a mesh, define the input matrixes `x` and `y`, run the jitted `f`, and compare the results producted by the unjitted and the jitted `f`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpenpN7xrpEN"
},
"outputs": [],
"source": [
"N = 1\n",
"M = 2\n",
"num_devices = 2 * M * M\n",
"\n",
"devices = np.array(list(jax.devices())[:num_devices])\n",
"if devices.size \u003c num_devices:\n",
" raise ValueError(f'Requires {num_devices} devices')\n",
"device_mesh = Mesh(devices.reshape((2 * M, M)), ('x', 'y'))\n",
"\n",
"sharding_x = NamedSharding(device_mesh, PartitionSpec(None, None, 'x'))\n",
"sharding_y = NamedSharding(device_mesh, PartitionSpec(None, None, 'y'))\n",
"jitted_f = jax.jit(f, in_shardings=(sharding_x, sharding_y), out_shardings=sharding_x)\n",
"\n",
"x = np.asarray(np.random.randint(0, 20, (2, 3, 4*N, 2*N)), dtype=np.float32)\n",
"y = np.asarray(np.random.randint(0, 20, (2, 3, 2*N, 4*N)), dtype=np.float32)\n",
"\n",
"result = f(x, y)\n",
"\n",
"with device_mesh:\n",
" jitted_result = jitted_f(x, y)\n",
"\n",
"for i in range(num_devices):\n",
" j = (i // M) * N\n",
" assert((np.asarray(jitted_result.addressable_shards[i].data) == result[:,:,j:j+N,:]).all())"
]
},
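{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (a sketch added here, not part of the original flow), we can look at the module that `jitted_f` lowers to. With the Shardy partitioner enabled, the custom-partitioned call is expected to carry the sharding rule we registered, though the exact textual form varies across JAX versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect the lowered module for the jitted f. With Shardy enabled,\n",
"# the custom-partitioned call should be annotated with the registered sharding\n",
"# rule (the exact attribute names vary across JAX versions).\n",
"print(jitted_f.lower(x, y).as_text())"
]
},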
{
"cell_type": "markdown",
"metadata": {