@@ -267,10 +267,82 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
267267 1 => Shard . from_config ( arg1 , % { 0 => [ 0 .. 2 ] , 1 => [ 0 .. 0 , 1 .. 1 ] } )
268268 } )
269269
270+ # This ensures the data hasn't been split
270271 assert { [ { _id , :none , out_expr , sources } ] , _state , _cache } =
271272 GraphSplitter . traverse ( expr , expr_shards )
272273
273- assert out_expr == expr
274+ # Following assertions ensure that:
275+ # - Shards are properly propagated to the output;
276+ # - The expression is unchanged aside from extra metadata nodes;
277+ # - And that the shards are set to the parameters too
278+ assert % T {
279+ data: % Expr {
280+ op: :metadata ,
281+ args: [
282+ % T {
283+ data: % Expr {
284+ op: :divide ,
285+ args: [
286+ % T {
287+ data: % Expr {
288+ op: :multiply ,
289+ args: [
290+ % T { data: % Expr { op: :constant , args: [ 3 ] } } ,
291+ % T { data: % Expr { op: :dot , args: [ t0 , _ , _ , t1 , _ , _ ] } }
292+ ]
293+ }
294+ } ,
295+ % T { data: % Expr { op: :constant , args: [ 4 ] } }
296+ ]
297+ }
298+ } ,
299+ % { shards: output_shards }
300+ ]
301+ }
302+ } = out_expr
303+
304+ assert sharded_expr . data . shards == output_shards
305+
306+ % T {
307+ data: % Expr {
308+ op: :add ,
309+ args: [
310+ % T { data: % Expr { op: :constant , args: [ 1 ] } } ,
311+ % T {
312+ data: % Expr {
313+ op: :metadata ,
314+ args: [ % T { data: % Expr { op: :parameter , args: [ 0 ] } } , % { shards: arg0_shards } ]
315+ }
316+ }
317+ ]
318+ }
319+ } = t0
320+
321+ assert % {
322+ 0 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ] ,
323+ 1 => [ % Shard { start: 0 , length: 3 } ]
324+ } = arg0_shards
325+
326+ % T {
327+ data: % Expr {
328+ op: :subtract ,
329+ args: [
330+ % T {
331+ data: % Expr {
332+ op: :metadata ,
333+ args: [ % T { data: % Expr { op: :parameter , args: [ 1 ] } } , % { shards: arg1_shards } ]
334+ }
335+ } ,
336+ % T { data: % Expr { op: :constant , args: [ 2 ] } }
337+ ]
338+ }
339+ } = t1
340+
341+ assert % {
342+ 0 => [ % Shard { start: 0 , length: 3 } ] ,
343+ 1 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ]
344+ } = arg1_shards
345+
274346 assert Enum . all? ( sources , fn { _id , source } -> source == nil end )
275347 end
276348
@@ -305,13 +377,37 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
305377
306378 assert { [ _ , _ ] , _state , _cache } = GraphSplitter . traverse ( expr , expr_shards )
307379
308- { _sharded_expr , _cache , % { expr_shards: expr_shards } } =
380+ { sharded_expr , _cache , % { expr_shards: expr_shards } } =
309381 ShardPropagation . traverse ( expr , % {
310382 0 => Shard . from_config ( arg0 , % { 0 => [ 0 .. 0 , 1 .. 1 ] , 1 => [ 0 .. 2 ] } ) ,
311383 1 => Shard . from_config ( arg1 , % { } )
312384 } )
313385
314- assert { [ _ , _ ] , _state , _cache } = GraphSplitter . traverse ( expr , expr_shards )
386+ assert { [ { _ , _ , stage_0_expr , _ } , { _ , _ , stage_1_expr , _ } ] , _state , _cache } =
387+ GraphSplitter . traverse ( expr , expr_shards )
388+
389+ assert { % T { data: % Expr { op: :metadata , args: [ _left , % { shards: left_shards } ] } } ,
390+ % T { data: % Expr { op: :metadata , args: [ _right , % { shards: right_shards } ] } } } =
391+ stage_0_expr
392+
393+ assert % {
394+ 0 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ] ,
395+ 1 => [ % Shard { start: 0 , length: 3 } ]
396+ } = left_shards
397+
398+ assert % {
399+ 0 => [
400+ % Shard { start: 0 , length: 1 } ,
401+ % Shard { start: 1 , length: 1 } ,
402+ % Shard { start: 2 , length: 1 }
403+ ] ,
404+ 1 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ]
405+ } = right_shards
406+
407+ assert % T { data: % Expr { op: :metadata , args: [ _out , % { shards: out_shards } ] } } =
408+ stage_1_expr
409+
410+ assert out_shards == sharded_expr . data . shards
315411 end
316412 end
317413end
0 commit comments