Skip to content

Commit 406d0d1

Browse files
committed
IterDomain resize for pad, cat, slice
Cherry-pick of csarofeen/pytorch@1e30fee Original PR: csarofeen/pytorch#2480
1 parent 48b0cb4 commit 406d0d1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+3755
-404
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ if(BUILD_TEST)
353353
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_outer_reduction.cpp)
354354
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_loop_rotation.cpp)
355355
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp)
356+
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_resize.cpp)
356357
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp)
357358
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_matmul_sass.cpp)
358359
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)

csrc/codegen.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2774,6 +2774,41 @@ class CudaKernelGenerator : private OptOutConstDispatch {
27742774
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
27752775
}
27762776

2777+
void handle(const CatOp* cat) final {
2778+
auto out = gen(cat->output(0));
2779+
2780+
// Generate code like:
2781+
// if (consumer_idx < producer_0_extent) {
2782+
// consumer[consumer_idx] = produce_0[producer_idx0];
2783+
// } else if (consumer_idx < producer_1_extent) {
2784+
// consumer[consumer_idx] = produce_1[producer_idx1];
2785+
// } else if (consumer_idx < producer_2_extent) {
2786+
// consumer[consumer_idx] = produce_2[producer_idx2];
2787+
// } else {
2788+
// consumer[consumer_idx] = produce_3[producer_idx3];
2789+
// }
2790+
2791+
for (const auto i : c10::irange(cat->inputs().size())) {
2792+
auto inp = cat->input(i)->as<kir::TensorIndex>();
2793+
auto inp_str = gen(inp);
2794+
if (i < cat->inputs().size() - 1) {
2795+
if (i == 0) {
2796+
indent() << "if (";
2797+
} else {
2798+
indent() << "} else if (";
2799+
}
2800+
code_ << gen(cat->getPred(i)) << ") {\n";
2801+
} else {
2802+
// last case doesn't need to be predicated
2803+
indent() << "} else {\n";
2804+
}
2805+
2806+
indent() << kTab << out << " = " << gen(inp) << ";\n";
2807+
}
2808+
2809+
indent() << "}\n";
2810+
}
2811+
27772812
private:
27782813
std::stringstream code_;
27792814
const kir::Kernel* kernel_;

csrc/compute_at_map.cpp

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ bool IterDomainGraph::exprsMap(
9999
}
100100

101101
TORCH_INTERNAL_ASSERT(
102-
first->isA<Merge>() || first->isA<Split>(),
103-
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
102+
first->isA<Merge>() || first->isA<Split>() || first->isA<Resize>(),
103+
"Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n",
104104
first->toString());
105105

106106
auto first_ids = ir_utils::filterByType<IterDomain>(
@@ -176,6 +176,15 @@ bool IterDomainGraph::exprsMap(
176176
}
177177
}
178178

179+
if (first->isA<Resize>()) {
180+
auto first_resize = first->as<Resize>();
181+
auto second_resize = second->as<Resize>();
182+
if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) ||
183+
!first_resize->rightExpand()->sameAs(second_resize->rightExpand())) {
184+
return false;
185+
}
186+
}
187+
179188
return true;
180189
}
181190

@@ -211,6 +220,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
211220
for (auto out_i : c10::irange(first_ids.size())) {
212221
exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
213222
permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
223+
permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
214224
}
215225
}
216226

@@ -407,6 +417,7 @@ void IterDomainGraph::build(Fusion* fusion) {
407417
auto id0 = *disjoint_set->begin();
408418
for (auto id1 : disjoint_set->vector()) {
409419
permissive_nodes_.mapEntries(id0, id1);
420+
permissive_resize_nodes_.mapEntries(id0, id1);
410421
exact_nodes_.mapEntries(id0, id1);
411422
sibling_sets_.mapEntries(id0, id1);
412423
}
@@ -430,8 +441,22 @@ void IterDomainGraph::build(Fusion* fusion) {
430441
// Look for matching ID transformations in producer and consumer, replay
431442
// producer as consumer. We use the symmetric API of BestEffortReplay so
432443
// that both broadcast and squeeze are handled correctly.
444+
//
445+
// Note on the boolean flags: swizzles are skipped in both
446+
// producer and consumer but resizes are not.
433447
const auto permissive_disjoint_sets =
434-
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
448+
BestEffortReplay::replayPasC(
449+
p_tv, c_tv, -1, pairwise_map, true, true, false)
450+
.getIterDomainEquivalence();
451+
452+
// Permissive-Resize map allows mappings of resize inputs and
453+
// outputs
454+
//
455+
// Note on the boolean flags: swizzles and resizes are skipped
456+
// in the permissive-resize map
457+
const auto permissive_resize_disjoint_sets =
458+
BestEffortReplay::replayPasC(
459+
p_tv, c_tv, -1, pairwise_map, true, true, true)
435460
.getIterDomainEquivalence();
436461

437462
// For exact mapings do not map any broadcast dimensions to
@@ -483,16 +508,12 @@ void IterDomainGraph::build(Fusion* fusion) {
483508
for (auto j : c10::irange(i + 1, vec.size())) {
484509
auto id2 = vec[j];
485510
if (p_ids.count(id1) && c_ids.count(id2)) {
486-
consumers_.at(id1).pushBack(id2);
487-
producers_.at(id2).pushBack(id1);
488511
if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
489512
idIsALeafDomain(id2, c_tv)) {
490513
loop_nodes_.mapEntries(id1, id2);
491514
}
492515
}
493516
if (c_ids.count(id1) && p_ids.count(id2)) {
494-
producers_.at(id1).pushBack(id2);
495-
consumers_.at(id2).pushBack(id1);
496517
if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
497518
idIsALeafDomain(id1, c_tv)) {
498519
loop_nodes_.mapEntries(id1, id2);
@@ -501,6 +522,31 @@ void IterDomainGraph::build(Fusion* fusion) {
501522
}
502523
}
503524
}
525+
526+
// Mostly the same as the above for the permissive map but
527+
// nothing to do for the loop map.
528+
// The producer and consumer maps are based on the most
529+
// permissive mappings, so they are set using the
530+
// permissive-resize mappings.
531+
for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) {
532+
auto& vec = dset->vector();
533+
for (auto i : c10::irange(vec.size())) {
534+
auto id1 = vec[i];
535+
permissive_resize_nodes_.mapEntries(id1, vec[0]);
536+
mapMaybeSwizzleOp(permissive_resize_nodes_, id1);
537+
for (auto j : c10::irange(i + 1, vec.size())) {
538+
auto id2 = vec[j];
539+
if (p_ids.count(id1) && c_ids.count(id2)) {
540+
consumers_.at(id1).pushBack(id2);
541+
producers_.at(id2).pushBack(id1);
542+
}
543+
if (c_ids.count(id1) && p_ids.count(id2)) {
544+
producers_.at(id1).pushBack(id2);
545+
consumers_.at(id2).pushBack(id1);
546+
}
547+
}
548+
}
549+
}
504550
}
505551
}
506552
}
@@ -561,7 +607,7 @@ void IterDomainGraph::build(Fusion* fusion) {
561607
for (auto expr : exprs) {
562608
auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
563609
TORCH_INTERNAL_ASSERT(
564-
expr->isA<Split>() || expr->isA<Merge>(),
610+
expr->isA<Split>() || expr->isA<Merge>() || expr->isA<Resize>(),
565611
"Wasn't expecting the expression type of:\n",
566612
expr->toString(),
567613
"\nto be an expression defined in an rfactor transformation.");
@@ -688,6 +734,7 @@ void IterDomainGraph::initializeId(
688734
bool is_rfactor_id,
689735
bool is_leaf_id) {
690736
permissive_nodes_.initializeSet(id);
737+
permissive_resize_nodes_.initializeSet(id);
691738
exact_nodes_.initializeSet(id);
692739
if (is_leaf_id) {
693740
loop_nodes_.initializeSet(id);
@@ -1127,6 +1174,17 @@ void ComputeAtMap::buildConcreteIds() {
11271174
auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP);
11281175
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
11291176
}
1177+
1178+
for (const auto& disjoint_set_shared_ptr :
1179+
id_graph_.permissiveResizeNodes().disjointSets()) {
1180+
TORCH_INTERNAL_ASSERT(
1181+
disjoint_set_shared_ptr->vector().size(),
1182+
"Cannot compute concrete id of empty set.");
1183+
auto first_id = disjoint_set_shared_ptr->vector().front();
1184+
auto concrete_id =
1185+
computeConcreteId(first_id, IdMappingMode::PERMISSIVE_RESIZE);
1186+
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
1187+
}
11301188
}
11311189

11321190
bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
@@ -1349,6 +1407,8 @@ std::string ComputeAtMap::toString() const {
13491407
ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP);
13501408
ss << "Permissive map:\n"
13511409
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE);
1410+
ss << "Permissive-Resize map:\n"
1411+
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE);
13521412
ss << "Consumer maps:\n";
13531413
for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) {
13541414
auto consumers = id_graph_.consumers().at(key);
@@ -1408,6 +1468,8 @@ const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
14081468
return id_graph_.loopNodes();
14091469
case IdMappingMode::PERMISSIVE:
14101470
return id_graph_.permissiveNodes();
1471+
case IdMappingMode::PERMISSIVE_RESIZE:
1472+
return id_graph_.permissiveResizeNodes();
14111473
}
14121474
TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided.");
14131475
}

csrc/compute_at_map.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ namespace nvfuser {
5353
// Map all iteration domains
5454
// Always contain root mappings (otherwise they could have been forwarded in
5555
// broadcast)
56+
// IdMappingMode::PERMISSIVE_RESIZE
57+
// Include everything in PERMISSIVE. Map also domains that are
58+
// inputs and outputs of resize ops. Used for, e.g., propagating
59+
// parallel types across those domains.
5660
// IdMappingMode::EXACT
5761
// Don't map any broadcast axes to non-broadcast axes
5862
// Do not forward through any broadcast IDs
@@ -79,6 +83,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
7983
const DisjointSets<IterDomain*>& loopNodes() const {
8084
return loop_nodes_;
8185
}
86+
const DisjointSets<IterDomain*>& permissiveResizeNodes() const {
87+
return permissive_resize_nodes_;
88+
}
8289

8390
// Consumers and producers is not symmetric like the other sets
8491
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
@@ -132,8 +139,11 @@ class TORCH_CUDA_CU_API IterDomainGraph {
132139
DisjointSets<IterDomain*> exact_nodes_;
133140
DisjointSets<IterDomain*> almost_exact_nodes_;
134141
DisjointSets<IterDomain*> loop_nodes_;
142+
DisjointSets<IterDomain*> permissive_resize_nodes_;
135143

136-
// Consumers and producers is not symmetric like the other sets
144+
// Consumers and producers is not symmetric like the other sets.
145+
// Mapping is based on the most permissive map, i.e., the
146+
// permissive-resize map.
137147
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
138148
consumers_;
139149
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>

csrc/contiguity.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) {
331331
}
332332
}
333333

334+
void OrderedIdInformation::handle(Resize* resize) {
335+
// Find inputs in the active_ids_ vector
336+
const auto in_it =
337+
std::find(active_ids_.begin(), active_ids_.end(), resize->in());
338+
339+
if (in_it == active_ids_.end()) {
340+
return;
341+
}
342+
343+
auto in_pos = std::distance(active_ids_.begin(), in_it);
344+
345+
// Find inputs in the ordered transforms map
346+
const auto in_ordered_it = consistently_ordered_ids_.find(resize->in());
347+
348+
bool in_ordered = in_ordered_it != consistently_ordered_ids_.end();
349+
350+
// Get root ids of the two inputs
351+
const auto in_root_ids_it = id_to_root_ids_.find(resize->in());
352+
353+
TORCH_INTERNAL_ASSERT(
354+
in_root_ids_it != id_to_root_ids_.end(),
355+
"Error replaying transforms in contiguous ID checker.");
356+
357+
const auto& in_root_ids = in_root_ids_it->second;
358+
359+
// Update map for outputs
360+
// Remove inputs from the active_ids_ and insert the output ID
361+
active_ids_[in_pos] = resize->out();
362+
363+
// Not completely certain, but propagating these properties should e
364+
// fine
365+
if (in_ordered) {
366+
consistently_ordered_ids_.emplace(resize->out());
367+
}
368+
369+
if (exclusivelyConsumesRoots(resize->in())) {
370+
exclusively_consumes_roots_.emplace(resize->out());
371+
}
372+
373+
id_to_root_ids_[resize->out()] = in_root_ids;
374+
}
375+
334376
NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
335377
// TODO: Revisit reduction rfactor axes and propagation. Should probably use
336378
// ca_map to propogate non divisibility dependencies across exact map. Still
@@ -500,6 +542,19 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
500542
{root_domain_.begin(), root_domain_.end()},
501543
{ids.begin(), ids.end()});
502544
for (auto expr : exprs) {
545+
if (auto resize = dynamic_cast<Resize*>(expr)) {
546+
resize_deps_.insert(resize->out());
547+
} else {
548+
if (std::any_of(
549+
expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) {
550+
return inp->isA<IterDomain>() &&
551+
resize_deps_.count(inp->as<IterDomain>());
552+
})) {
553+
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
554+
resize_deps_.insert(out);
555+
}
556+
}
557+
}
503558
handle(expr);
504559
}
505560
}
@@ -576,6 +631,12 @@ void ContigIDs::handle(Merge* merge) {
576631
return;
577632
}
578633

634+
// Don't allow contig indexing after resize as we need traverse back
635+
// at least to direct outputs of resize ops
636+
if (resize_deps_.count(merge->out())) {
637+
return;
638+
}
639+
579640
// All broadcasting
580641
if (last_root == nullptr) {
581642
return;

csrc/contiguity.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class OrderedIdInformation : public OptInDispatch {
6262

6363
void handle(Swizzle2D* swizzle) override;
6464

65+
void handle(Resize* resize) override;
66+
6567
// Track which root ids were used to generate each iter domain
6668
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
6769
id_to_root_ids_;
@@ -255,6 +257,8 @@ class ContigIDs : public OptInDispatch {
255257
// cases, depending on specific swizzle type and axes.
256258
void handle(Swizzle2D* swizzle) override {}
257259

260+
void handle(Resize* resize) override {}
261+
258262
IterDomain* getCAIndexConcreteId(IterDomain* id) const;
259263

260264
//! True if an ID is indexable.
@@ -307,6 +311,9 @@ class ContigIDs : public OptInDispatch {
307311
std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;
308312

309313
NonDivisibleSplitDependencies non_divisible_id_info_;
314+
315+
//! IDs that depend on resize output IDs
316+
std::unordered_set<IterDomain*> resize_deps_;
310317
};
311318

312319
} // namespace nvfuser

0 commit comments

Comments
 (0)