@@ -267,10 +267,82 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
267
267
1 => Shard . from_config ( arg1 , % { 0 => [ 0 .. 2 ] , 1 => [ 0 .. 0 , 1 .. 1 ] } )
268
268
} )
269
269
270
+ # This ensures the data hasn't been split
270
271
assert { [ { _id , :none , out_expr , sources } ] , _state , _cache } =
271
272
GraphSplitter . traverse ( expr , expr_shards )
272
273
273
- assert out_expr == expr
274
+ # Following assertions ensure that:
275
+ # - Shards are properly propagated to the output;
276
+ # - The expression is unchanged aside from extra metadata nodes;
277
+ # - And that the shards are set to the parameters too
278
+ assert % T {
279
+ data: % Expr {
280
+ op: :metadata ,
281
+ args: [
282
+ % T {
283
+ data: % Expr {
284
+ op: :divide ,
285
+ args: [
286
+ % T {
287
+ data: % Expr {
288
+ op: :multiply ,
289
+ args: [
290
+ % T { data: % Expr { op: :constant , args: [ 3 ] } } ,
291
+ % T { data: % Expr { op: :dot , args: [ t0 , _ , _ , t1 , _ , _ ] } }
292
+ ]
293
+ }
294
+ } ,
295
+ % T { data: % Expr { op: :constant , args: [ 4 ] } }
296
+ ]
297
+ }
298
+ } ,
299
+ % { shards: output_shards }
300
+ ]
301
+ }
302
+ } = out_expr
303
+
304
+ assert sharded_expr . data . shards == output_shards
305
+
306
+ % T {
307
+ data: % Expr {
308
+ op: :add ,
309
+ args: [
310
+ % T { data: % Expr { op: :constant , args: [ 1 ] } } ,
311
+ % T {
312
+ data: % Expr {
313
+ op: :metadata ,
314
+ args: [ % T { data: % Expr { op: :parameter , args: [ 0 ] } } , % { shards: arg0_shards } ]
315
+ }
316
+ }
317
+ ]
318
+ }
319
+ } = t0
320
+
321
+ assert % {
322
+ 0 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ] ,
323
+ 1 => [ % Shard { start: 0 , length: 3 } ]
324
+ } = arg0_shards
325
+
326
+ % T {
327
+ data: % Expr {
328
+ op: :subtract ,
329
+ args: [
330
+ % T {
331
+ data: % Expr {
332
+ op: :metadata ,
333
+ args: [ % T { data: % Expr { op: :parameter , args: [ 1 ] } } , % { shards: arg1_shards } ]
334
+ }
335
+ } ,
336
+ % T { data: % Expr { op: :constant , args: [ 2 ] } }
337
+ ]
338
+ }
339
+ } = t1
340
+
341
+ assert % {
342
+ 0 => [ % Shard { start: 0 , length: 3 } ] ,
343
+ 1 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ]
344
+ } = arg1_shards
345
+
274
346
assert Enum . all? ( sources , fn { _id , source } -> source == nil end )
275
347
end
276
348
@@ -305,13 +377,37 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do
305
377
306
378
assert { [ _ , _ ] , _state , _cache } = GraphSplitter . traverse ( expr , expr_shards )
307
379
308
- { _sharded_expr , _cache , % { expr_shards: expr_shards } } =
380
+ { sharded_expr , _cache , % { expr_shards: expr_shards } } =
309
381
ShardPropagation . traverse ( expr , % {
310
382
0 => Shard . from_config ( arg0 , % { 0 => [ 0 .. 0 , 1 .. 1 ] , 1 => [ 0 .. 2 ] } ) ,
311
383
1 => Shard . from_config ( arg1 , % { } )
312
384
} )
313
385
314
- assert { [ _ , _ ] , _state , _cache } = GraphSplitter . traverse ( expr , expr_shards )
386
+ assert { [ { _ , _ , stage_0_expr , _ } , { _ , _ , stage_1_expr , _ } ] , _state , _cache } =
387
+ GraphSplitter . traverse ( expr , expr_shards )
388
+
389
+ assert { % T { data: % Expr { op: :metadata , args: [ _left , % { shards: left_shards } ] } } ,
390
+ % T { data: % Expr { op: :metadata , args: [ _right , % { shards: right_shards } ] } } } =
391
+ stage_0_expr
392
+
393
+ assert % {
394
+ 0 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ] ,
395
+ 1 => [ % Shard { start: 0 , length: 3 } ]
396
+ } = left_shards
397
+
398
+ assert % {
399
+ 0 => [
400
+ % Shard { start: 0 , length: 1 } ,
401
+ % Shard { start: 1 , length: 1 } ,
402
+ % Shard { start: 2 , length: 1 }
403
+ ] ,
404
+ 1 => [ % Shard { start: 0 , length: 1 } , % Shard { start: 1 , length: 1 } ]
405
+ } = right_shards
406
+
407
+ assert % T { data: % Expr { op: :metadata , args: [ _out , % { shards: out_shards } ] } } =
408
+ stage_1_expr
409
+
410
+ assert out_shards == sharded_expr . data . shards
315
411
end
316
412
end
317
413
end
0 commit comments