@@ -99,8 +99,8 @@ bool IterDomainGraph::exprsMap(
9999 }
100100
101101 TORCH_INTERNAL_ASSERT (
102- first->isA <Merge>() || first->isA <Split>(),
103- " Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n " ,
102+ first->isA <Merge>() || first->isA <Split>() || first-> isA <Resize>() ,
103+ " Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n " ,
104104 first->toString ());
105105
106106 auto first_ids = ir_utils::filterByType<IterDomain>(
@@ -176,6 +176,15 @@ bool IterDomainGraph::exprsMap(
176176 }
177177 }
178178
179+ if (first->isA <Resize>()) {
180+ auto first_resize = first->as <Resize>();
181+ auto second_resize = second->as <Resize>();
182+ if (!first_resize->leftExpand ()->sameAs (second_resize->leftExpand ()) ||
183+ !first_resize->rightExpand ()->sameAs (second_resize->rightExpand ())) {
184+ return false ;
185+ }
186+ }
187+
179188 return true ;
180189}
181190
@@ -211,6 +220,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
211220 for (auto out_i : c10::irange (first_ids.size ())) {
212221 exact_nodes_.mapEntries (first_ids[out_i], second_ids[out_i]);
213222 permissive_nodes_.mapEntries (first_ids[out_i], second_ids[out_i]);
223+ permissive_resize_nodes_.mapEntries (first_ids[out_i], second_ids[out_i]);
214224 }
215225}
216226
@@ -407,6 +417,7 @@ void IterDomainGraph::build(Fusion* fusion) {
407417 auto id0 = *disjoint_set->begin ();
408418 for (auto id1 : disjoint_set->vector ()) {
409419 permissive_nodes_.mapEntries (id0, id1);
420+ permissive_resize_nodes_.mapEntries (id0, id1);
410421 exact_nodes_.mapEntries (id0, id1);
411422 sibling_sets_.mapEntries (id0, id1);
412423 }
@@ -430,8 +441,22 @@ void IterDomainGraph::build(Fusion* fusion) {
430441 // Look for matching ID transformations in producer and consumer, replay
431442 // producer as consumer. We use the symmetric API of BestEffortReplay so
432443 // that both broadcast and squeeze are handled correctly.
444+ //
445+ // Note on the boolean flags: swizzles are skipped in both
446+ // producer and consumer but resizes are not.
433447 const auto permissive_disjoint_sets =
434- BestEffortReplay::replayPasC (p_tv, c_tv, -1 , pairwise_map)
448+ BestEffortReplay::replayPasC (
449+ p_tv, c_tv, -1 , pairwise_map, true , true , false )
450+ .getIterDomainEquivalence ();
451+
452+ // Permissive-Resize map allows mappings of resize inputs and
453+ // outputs
454+ //
455+ // Note on the boolean flags: swizzles and resizes are skipped
456+ // in the permissive-resize map
457+ const auto permissive_resize_disjoint_sets =
458+ BestEffortReplay::replayPasC (
459+ p_tv, c_tv, -1 , pairwise_map, true , true , true )
435460 .getIterDomainEquivalence ();
436461
437462 // For exact mappings do not map any broadcast dimensions to
@@ -483,16 +508,12 @@ void IterDomainGraph::build(Fusion* fusion) {
483508 for (auto j : c10::irange (i + 1 , vec.size ())) {
484509 auto id2 = vec[j];
485510 if (p_ids.count (id1) && c_ids.count (id2)) {
486- consumers_.at (id1).pushBack (id2);
487- producers_.at (id2).pushBack (id1);
488511 if (idIsAComputeAtLeafDomain (id1, p_tv, c_tv) &&
489512 idIsALeafDomain (id2, c_tv)) {
490513 loop_nodes_.mapEntries (id1, id2);
491514 }
492515 }
493516 if (c_ids.count (id1) && p_ids.count (id2)) {
494- producers_.at (id1).pushBack (id2);
495- consumers_.at (id2).pushBack (id1);
496517 if (idIsAComputeAtLeafDomain (id2, p_tv, c_tv) &&
497518 idIsALeafDomain (id1, c_tv)) {
498519 loop_nodes_.mapEntries (id1, id2);
@@ -501,6 +522,31 @@ void IterDomainGraph::build(Fusion* fusion) {
501522 }
502523 }
503524 }
525+
526+ // Mostly the same as the above for the permissive map but
527+ // nothing to do for the loop map.
528+ // The producer and consumer maps are based on the most
529+ // permissive mappings, so they are set using the
530+ // permissive-resize mappings.
531+ for (auto & dset : permissive_resize_disjoint_sets.disjointSets ()) {
532+ auto & vec = dset->vector ();
533+ for (auto i : c10::irange (vec.size ())) {
534+ auto id1 = vec[i];
535+ permissive_resize_nodes_.mapEntries (id1, vec[0 ]);
536+ mapMaybeSwizzleOp (permissive_resize_nodes_, id1);
537+ for (auto j : c10::irange (i + 1 , vec.size ())) {
538+ auto id2 = vec[j];
539+ if (p_ids.count (id1) && c_ids.count (id2)) {
540+ consumers_.at (id1).pushBack (id2);
541+ producers_.at (id2).pushBack (id1);
542+ }
543+ if (c_ids.count (id1) && p_ids.count (id2)) {
544+ producers_.at (id1).pushBack (id2);
545+ consumers_.at (id2).pushBack (id1);
546+ }
547+ }
548+ }
549+ }
504550 }
505551 }
506552 }
@@ -561,7 +607,7 @@ void IterDomainGraph::build(Fusion* fusion) {
561607 for (auto expr : exprs) {
562608 auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs ());
563609 TORCH_INTERNAL_ASSERT (
564- expr->isA <Split>() || expr->isA <Merge>(),
610+ expr->isA <Split>() || expr->isA <Merge>() || expr-> isA <Resize>() ,
565611 " Wasn't expecting the expression type of:\n " ,
566612 expr->toString (),
567613 " \n to be an expression defined in an rfactor transformation." );
@@ -688,6 +734,7 @@ void IterDomainGraph::initializeId(
688734 bool is_rfactor_id,
689735 bool is_leaf_id) {
690736 permissive_nodes_.initializeSet (id);
737+ permissive_resize_nodes_.initializeSet (id);
691738 exact_nodes_.initializeSet (id);
692739 if (is_leaf_id) {
693740 loop_nodes_.initializeSet (id);
@@ -1127,6 +1174,17 @@ void ComputeAtMap::buildConcreteIds() {
11271174 auto concrete_id = computeConcreteId (first_id, IdMappingMode::LOOP);
11281175 concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
11291176 }
1177+
1178+ for (const auto & disjoint_set_shared_ptr :
1179+ id_graph_.permissiveResizeNodes ().disjointSets ()) {
1180+ TORCH_INTERNAL_ASSERT (
1181+ disjoint_set_shared_ptr->vector ().size (),
1182+ " Cannot compute concrete id of empty set." );
1183+ auto first_id = disjoint_set_shared_ptr->vector ().front ();
1184+ auto concrete_id =
1185+ computeConcreteId (first_id, IdMappingMode::PERMISSIVE_RESIZE);
1186+ concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
1187+ }
11301188}
11311189
11321190bool ComputeAtMap::areExactExprs (Expr* expr_1, Expr* expr_2) {
@@ -1349,6 +1407,8 @@ std::string ComputeAtMap::toString() const {
13491407 ss << " Loop map:\n " << idGraphNodesToString (*this , IdMappingMode::LOOP);
13501408 ss << " Permissive map:\n "
13511409 << idGraphNodesToString (*this , IdMappingMode::PERMISSIVE);
1410+ ss << " Permissive-Resize map:\n "
1411+ << idGraphNodesToString (*this , IdMappingMode::PERMISSIVE_RESIZE);
13521412 ss << " Consumer maps:\n " ;
13531413 for (auto key : getSortedKeys (id_graph_.consumers (), Statement::lessThan)) {
13541414 auto consumers = id_graph_.consumers ().at (key);
@@ -1408,6 +1468,8 @@ const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
14081468 return id_graph_.loopNodes ();
14091469 case IdMappingMode::PERMISSIVE:
14101470 return id_graph_.permissiveNodes ();
1471+ case IdMappingMode::PERMISSIVE_RESIZE:
1472+ return id_graph_.permissiveResizeNodes ();
14111473 }
14121474 TORCH_INTERNAL_ASSERT (false , " Error with mapping mode provided." );
14131475}
0 commit comments