Skip to content

Commit 4206732

Browse files
committed
test: assertions for shard propagation
1 parent 2b10eca commit 4206732

File tree

2 files changed

+99
-4
lines changed

2 files changed

+99
-4
lines changed

nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
33

44
alias Nx.Tensor, as: T
55
alias Nx.Defn.Expr
6-
alias Nx.Defn.ShardingCompiler.Passes.ShardPropagation
76
alias Nx.Defn.ShardingCompiler.Shard
87

98
@gather_ops [:dot]

nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,82 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
267267
1 => Shard.from_config(arg1, %{0 => [0..2], 1 => [0..0, 1..1]})
268268
})
269269

270+
# This ensures the data hasn't been split
270271
assert {[{_id, :none, out_expr, sources}], _state, _cache} =
271272
GraphSplitter.traverse(expr, expr_shards)
272273

273-
assert out_expr == expr
274+
# Following assertions ensure that:
275+
# - Shards are properly propagated to the output;
276+
# - The expression is unchanged aside from extra metadata nodes;
277+
# - And that the shards are set to the parameters too
278+
assert %T{
279+
data: %Expr{
280+
op: :metadata,
281+
args: [
282+
%T{
283+
data: %Expr{
284+
op: :divide,
285+
args: [
286+
%T{
287+
data: %Expr{
288+
op: :multiply,
289+
args: [
290+
%T{data: %Expr{op: :constant, args: [3]}},
291+
%T{data: %Expr{op: :dot, args: [t0, _, _, t1, _, _]}}
292+
]
293+
}
294+
},
295+
%T{data: %Expr{op: :constant, args: [4]}}
296+
]
297+
}
298+
},
299+
%{shards: output_shards}
300+
]
301+
}
302+
} = out_expr
303+
304+
assert sharded_expr.data.shards == output_shards
305+
306+
%T{
307+
data: %Expr{
308+
op: :add,
309+
args: [
310+
%T{data: %Expr{op: :constant, args: [1]}},
311+
%T{
312+
data: %Expr{
313+
op: :metadata,
314+
args: [%T{data: %Expr{op: :parameter, args: [0]}}, %{shards: arg0_shards}]
315+
}
316+
}
317+
]
318+
}
319+
} = t0
320+
321+
assert %{
322+
0 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}],
323+
1 => [%Shard{start: 0, length: 3}]
324+
} = arg0_shards
325+
326+
%T{
327+
data: %Expr{
328+
op: :subtract,
329+
args: [
330+
%T{
331+
data: %Expr{
332+
op: :metadata,
333+
args: [%T{data: %Expr{op: :parameter, args: [1]}}, %{shards: arg1_shards}]
334+
}
335+
},
336+
%T{data: %Expr{op: :constant, args: [2]}}
337+
]
338+
}
339+
} = t1
340+
341+
assert %{
342+
0 => [%Shard{start: 0, length: 3}],
343+
1 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}]
344+
} = arg1_shards
345+
274346
assert Enum.all?(sources, fn {_id, source} -> source == nil end)
275347
end
276348

@@ -305,13 +377,37 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
305377

306378
assert {[_, _], _state, _cache} = GraphSplitter.traverse(expr, expr_shards)
307379

308-
{_sharded_expr, _cache, %{expr_shards: expr_shards}} =
380+
{sharded_expr, _cache, %{expr_shards: expr_shards}} =
309381
ShardPropagation.traverse(expr, %{
310382
0 => Shard.from_config(arg0, %{0 => [0..0, 1..1], 1 => [0..2]}),
311383
1 => Shard.from_config(arg1, %{})
312384
})
313385

314-
assert {[_, _], _state, _cache} = GraphSplitter.traverse(expr, expr_shards)
386+
assert {[{_, _, stage_0_expr, _}, {_, _, stage_1_expr, _}], _state, _cache} =
387+
GraphSplitter.traverse(expr, expr_shards)
388+
389+
assert {%T{data: %Expr{op: :metadata, args: [_left, %{shards: left_shards}]}},
390+
%T{data: %Expr{op: :metadata, args: [_right, %{shards: right_shards}]}}} =
391+
stage_0_expr
392+
393+
assert %{
394+
0 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}],
395+
1 => [%Shard{start: 0, length: 3}]
396+
} = left_shards
397+
398+
assert %{
399+
0 => [
400+
%Shard{start: 0, length: 1},
401+
%Shard{start: 1, length: 1},
402+
%Shard{start: 2, length: 1}
403+
],
404+
1 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}]
405+
} = right_shards
406+
407+
assert %T{data: %Expr{op: :metadata, args: [_out, %{shards: out_shards}]}} =
408+
stage_1_expr
409+
410+
assert out_shards == sharded_expr.data.shards
315411
end
316412
end
317413
end

0 commit comments

Comments
 (0)