diff --git a/build_variables.bzl b/build_variables.bzl index 0383024de541..80da8c99495f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -686,6 +686,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp", "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp", "torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp", + "torch/csrc/jit/codegen/cuda/lower_interleaved_loop.cpp", "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp", diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 76c5bc4ac82b..f09276adb449 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2641,7 +2641,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { loop->loopTransformInfo().predicate_peel_stage == PredicatePeelStage::Main || loop->loopTransformInfo().double_buffer_loop_stage == - DoubleBufferLoopStage::CircularInitProlog) { + DoubleBufferLoopStage::CircularInitProlog || + loop->isInterleaveUnit()) { code_ << " = " << gen_start << "; "; } else { // Do not start at the start of the ID when not parallelized. Instead, diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 44d7eac8065c..c6536c3343f8 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -529,6 +529,16 @@ class TORCH_CUDA_CU_API TensorView : public Val { return skew_double_buffer_loop_; } + void interleave(int axis, int number_of_units = 4) { + // TODO: move to tensorview.cpp + // and add upfront validation. + maybe_interleave_axis_and_factor_ = std::make_pair(axis, number_of_units); + } + + auto getMaybeInterleavedAxisAndFactor() const { + return maybe_interleave_axis_and_factor_; + } + //! Transforms the innermost iterdomains according to the given mma swizzle, //! this should be used on the tvs that are either inputs/outputs of an //! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to @@ -636,6 +646,9 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! Indicates if the prolog of the double buffer loop of double //! buffer tensor will be lifted out of the main loop. bool skew_double_buffer_loop_ = false; + + // Loop where the next level of unrolled loops are interleaved. + c10::optional> maybe_interleave_axis_and_factor_; }; //! A simple TensorView builder diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index dce6e140d41b..0bffc6b45b87 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -494,6 +494,11 @@ bool ForLoop::isTrivial() const { return true; } + if (isInterleaveUnit()) { + // Index is important in interleave unit. + return false; + } + // Extent-1 loop: for (int i = 0; i < 1; ++i) { if (start()->isZeroInt() && stop()->isOneInt() && step()->isOneInt()) { return true; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 8792315b2f1a..299d18f59d11 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -568,6 +568,9 @@ struct LoopTransformInfo { //! lifted memory address. bool is_base_index_loop = false; + //! Tracks if this for loop is a unit from an interleaved set of loops. + bool is_interleave_unit = false; + //! 
Tracks if this for loop is for calculating inductive variable //! increments. bool is_increment_loop = false; @@ -590,6 +593,12 @@ struct LoopTransformInfo { return *this; } + //! Setter API + LoopTransformInfo& interLeaveUnit() { + is_interleave_unit = true; + return *this; + } + // ! Setter API LoopTransformInfo& incrementLoop() { is_increment_loop = true; @@ -693,6 +702,10 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { return loop_transform_info_.is_base_index_loop; } + bool isInterleaveUnit() const { + return loop_transform_info_.is_interleave_unit; + } + private: //! Returns if a loop could be unrolled. bool isUnrollable() const; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index eeb6395b5caa..f00a09558ec5 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -317,6 +318,8 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { doubleBufferInfo().build(fusion_); + interleavedLoopInfo().build(fusion_); + compute_at_map_->allocateIndexVariables(); addressComputeInfo().build(fusion_); @@ -359,12 +362,15 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { const auto predicate_peeled = PredicatePeeling::peelPredicatedLoop(exprs_double_buffered); + const auto exprs_interleaved = + interLeaveDoubleBufferUnrolledLoops(predicate_peeled); + // This pass inserts predicates as well as branches in the code. Up until now // the code is explicitly single shot for loop based. Need to be careful in // later passes when doing any kind of insertions in loop nest structure as // insertions could be on if then or else instead of directly on a for loop. const auto exprs_unrolled_loops = - UnrollPass::runPass(fusion_, predicate_peeled); + UnrollPass::runPass(fusion_, exprs_interleaved); const auto exprs_unrolled_mv_loops = processMisalignedVectorization(exprs_unrolled_loops); diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 18f514933a9d..b2ef8e315ea9 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -175,6 +176,14 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { return sync_map_; } + auto& interleavedLoopInfo() { + return interleave_info_; + } + + const auto& interleavedLoopInfo() const { + return interleave_info_; + } + kir::KernelPerformanceProfile& profile() { return profile_; } @@ -236,6 +245,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { AddressComputeInfo address_compute_info_; PredicatePeelingInfo predicate_peeling_info_; kir::KernelPerformanceProfile profile_; + InterleaveLoopInfo interleave_info_; std::unordered_set divisible_splits_; // Track which tensor views are inputs or outputs of a vectorized operation diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index ef12cce8fd46..17b5ffab8dab 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -650,6 +650,12 @@ class BufferUseDefInfo { maybe_alloc_info.value()->can_use_inner_alias = false; } + if (input_tv->isDoubleBuffered()) { + // Do not inline re-use double buffered tensors, + // the whole allocated space is valid throughout. 
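+      // (Aliasing it would let another tensor overwrite stages that
+      // in-flight iterations of the double buffer loop still read.)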
+ maybe_alloc_info.value()->can_use_inner_alias = false; + } + auto outer_loop_info = ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 7f0c232d7abd..77330e8334b4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -17,7 +17,7 @@ unsigned int getDoubleBufferAxisPosition(const TensorView* tv) { // which defines the loop where prefetching is applied. Therefore, // the CA position must be larger than 0. - TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0); + TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0, tv->toString()); // Unroll must not exist outside of double-buffer axis auto first_unroll_it = std::find_if( @@ -337,7 +337,10 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } } - if (stage_depth > 2) { + // Need to insert commits for multi-stage circular buffering + // on the prologs, but do not need to wait for them until + // the main loop. + if (stage_depth > 2 && loop_type_ == DoubleBufferLoopStage::Prolog) { cloned_top_level_loop_->body().push_back( IrBuilder::create()); } @@ -821,6 +824,10 @@ class DoubleBufferInserter : private kir::ExprMutator { main_loop->iter_domain()); auto cp_async_wait = IrBuilder::create(stage_depth - 2); + // Make sure the commit is inserted right before the + // cp.async.wait in circular buffering. + bool need_insert_commit = stage_depth > 2; + // Check if a sync has been inserted by WAR sync pass. auto block_sync_it = std::find_if( main_loop->body().exprs().rbegin(), @@ -832,10 +839,18 @@ class DoubleBufferInserter : private kir::ExprMutator { // it can just be anywhere in the loop. Chose to // place at the end arbitrarily. main_loop->body().insert_after(end_of_loop_expr, cp_async_wait); + if (need_insert_commit) { + main_loop->body().insert_after( + end_of_loop_expr, IrBuilder::create()); + } } else { // If a sync has been inserted, wait needs to be placed // before the sync. main_loop->body().insert_before(*block_sync_it, cp_async_wait); + if (need_insert_commit) { + main_loop->body().insert_before( + *block_sync_it, IrBuilder::create()); + } } } diff --git a/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.cpp b/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.cpp new file mode 100644 index 000000000000..885fb819decd --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.cpp @@ -0,0 +1,717 @@ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Note: [Loop Interleaving]: +// This pass is trying to implement a simple yet useful loop structure +// optimization that tries to interleave sub iterations of unrolled loops. +// With an example: +// +// Before transform: +// for i0 in 0..4 +// expr1 +// for i1 in 0..8 +// expr2 +// for i2 in 0..4 +// expr3 +// After transform: +// for i0 in {0} +// expr1 +// for i1 in {0,1} +// expr2 +// for i2 in {0} +// expr3 +// for i0 in {1} +// expr1 +// for i1 in {2,3} +// expr2 +// ... +// +// To simplify the initial implementation, an outer serial loop is assumed, as +// an indicator to define at which loop nest level to start interleaving, so +// the actual transform looks like: (some terminology defined inline) +// Before transform: +// for i in ... 
// This outer serial loop is called the "main loop" in this pass.
+//   for i0 in 0..4  // Each of these unrolled loops is called a "subloop"
+//                   // of the "main loop".
+//     expr1
+//   for i1 in 0..8
+//     expr2
+//   for i2 in 0..4
+//     expr3
+// After transform:
+//  for i in ...
+//   for i0 in {0}   // Each of these sub-iterations is called an
+//                   // "interleave unit".
+//     expr1
+//   for i1 in {0,1}
+//     expr2
+//   for i2 in {0}
+//     expr3
+//   for i0 in {1}
+//     expr1
+//   for i1 in {2,3}
+//     expr2
+// ...
+//
+// This optimization is controlled by the scheduler through the interface:
+//   tv->interleave(pos, factor),
+// where `pos` is the position of the iterdomain that corresponds to the
+// main loop, and all the subloops are assumed to be at the immediately
+// following position.
+// E.g.
+//   tv[Io, Ii] -> tv->interleave(0, factor);
+// means that the "main loop" is selected to be the loop that is loop mapped
+// to Io, and Ii is assumed to map to one of the "subloops".
+//
+// The `factor` defines the number of "interleave units" each "subloop" is
+// split into, in a best-effort manner, with each unit of size
+// `ceilDiv(loop_extent, factor)`.
+//
+// E.g. if the factor is 4:
+//   subloop `for i in 0..8` becomes:
+//     `for i in 0..2`
+//     `for i in 2..4`
+//     `for i in 4..6`
+//     `for i in 6..8`
+//   subloop `for i in 0..7` becomes:
+//     `for i in 0..2`
+//     `for i in 2..4`
+//     `for i in 4..6`
+//     `for i in 6..7`
+//   subloop `for i in 0..6` becomes:
+//     `for i in 0..2`
+//     `for i in 2..4`
+//     `for i in 4..6`
+//
+// All the subloops are assumed to be constant sized, since they need to be
+// unrolled for this optimization to be meaningful.
+namespace {
+
+int64_t ceilDiv(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+//! Returns the next-level unrolled loop that is within the given
+//! main loop on a tensorview. Returns c10::nullopt if the unrolled
+//! loop cannot be found.
+c10::optional<IterDomain*> getMaybeSubloop(
+    TensorView* tv,
+    IterDomain* main_loop) {
+  bool main_loop_found = false;
+  const auto& ca_map = GpuLower::current()->caMap();
+
+  for (auto leaf_id : tv->domain()->domain()) {
+    if (main_loop_found && !leaf_id->isParallelized()) {
+      return ca_map->getConcreteMappedID(leaf_id, IdMappingMode::LOOP);
+    }
+    main_loop_found = main_loop_found ||
+        ca_map->areMapped(leaf_id, main_loop, IdMappingMode::LOOP);
+  }
+
+  return c10::nullopt;
+}
+
+} // namespace
+
+void InterleaveLoopInfo::build(Fusion* fusion) {
+  fusion_ = fusion;
+  auto used_math_vals = fusion->usedMathVals();
+  auto filtered_used_math_vals =
+      ir_utils::filterByType<TensorView>(used_math_vals);
+
+  // Cache used tvs for multiple visits.
+  used_tvs_ = {filtered_used_math_vals.begin(), filtered_used_math_vals.end()};
+
+  // Collect loop information from the fusion.
+  collectInterleaveMainLoops();
+  collectInterleavedSubLoops();
+
+  // Validate interleaved expressions for data consistency.
+  validate();
+}
+
+void InterleaveLoopInfo::collectInterleaveMainLoops() {
+  for (auto tv : used_tvs_) {
+    auto maybe_main_axis = tv->getMaybeInterleavedAxisAndFactor();
+    if (maybe_main_axis.has_value()) {
+      auto concrete_main_loop_id =
+          GpuLower::current()->caMap()->getConcreteMappedID(
+              tv->axis(maybe_main_axis.value().first), IdMappingMode::LOOP);
+
+      // Create a new record for this loop id if not found.
+      if (!concrete_main_loop_to_interleaved_tv_.count(concrete_main_loop_id)) {
+        // Create record space to later collect the interleaved tensors
+        // and the subloops.
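+        // (The key is the loop-mapped concrete id, so all tensors whose
+        // leaf domains loop-map to the same main loop share one record.)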
+ concrete_main_loop_to_subloop_map_.insert( + std::make_pair(concrete_main_loop_id, ConcreteIdVector())); + concrete_main_loop_to_interleaved_tv_.insert( + std::make_pair(concrete_main_loop_id, TensorViewVector())); + + // Record the interleave factor for this main loop. see [Loop + // Interleaving]. + concrete_main_loop_to_number_of_units_.insert(std::make_pair( + concrete_main_loop_id, maybe_main_axis.value().second)); + } + } + } +} + +bool InterleaveLoopInfo::isMainLoop(IterDomain* id) { + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::LOOP); + return concrete_main_loop_to_interleaved_tv_.count(concrete_id); +} + +bool InterleaveLoopInfo::isSubLoopOf( + IterDomain* id, + IterDomain* concrete_main_id) { + auto it = concrete_main_loop_to_subloop_map_.find(concrete_main_id); + TORCH_INTERNAL_ASSERT( + it != concrete_main_loop_to_subloop_map_.end(), "Invalid main loop"); + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::LOOP); + return it->second.has(concrete_id); +} + +void InterleaveLoopInfo::insertEntry( + TensorView* tv, + IterDomain* main_loop, + IterDomain* sub_loop) { + auto concrete_main_loop = GpuLower::current()->caMap()->getConcreteMappedID( + main_loop, IdMappingMode::LOOP); + auto concrete_sub_loop = GpuLower::current()->caMap()->getConcreteMappedID( + sub_loop, IdMappingMode::LOOP); + + // Insert sub loops from this tv + auto main_loop_entry_it = + concrete_main_loop_to_subloop_map_.find(concrete_main_loop); + TORCH_INTERNAL_ASSERT( + main_loop_entry_it != concrete_main_loop_to_subloop_map_.end(), + "unknown main loop: ", + main_loop->toString(), + " (", + concrete_main_loop->toString(), + ")"); + main_loop_entry_it->second.pushBack(concrete_sub_loop); + + // Insert interleaved tvs. + auto tv_entry_it = + concrete_main_loop_to_interleaved_tv_.find(concrete_main_loop); + TORCH_INTERNAL_ASSERT( + tv_entry_it != concrete_main_loop_to_interleaved_tv_.end(), + "unknown main loop: ", + main_loop->toString(), + " (", + concrete_main_loop->toString(), + ")"); + tv_entry_it->second.pushBack(tv); +} + +void InterleaveLoopInfo::collectInterleavedSubLoops() { + for (auto tv : used_tvs_) { + IterDomain* main_loop = nullptr; + for (auto leaf_id : tv->domain()->domain()) { + if (main_loop == nullptr) { + if (isMainLoop(leaf_id)) { + main_loop = leaf_id; + auto maybe_subloop = getMaybeSubloop(tv, leaf_id); + TORCH_INTERNAL_ASSERT( + maybe_subloop.has_value(), + tv->toString(), + " cannot be interleaved within ", + leaf_id); + insertEntry(tv, main_loop, maybe_subloop.value()); + } + } else { + // main loop already found. There should be no more + // main loop in this tensor + TORCH_INTERNAL_ASSERT( + !isMainLoop(leaf_id), + tv, + "has nested main loop ", + main_loop->toString(), + " and ", + leaf_id->toString(), + " which is not yet supported"); + } + } + } +} + +// Validation of double buffering topology of interleaved expressions: +// see [Supported Interleaving Cases] +void InterleaveLoopInfo::validate() { + // Validate expression consistency after interleaving + for (auto& main_loop_entry : concrete_main_loop_to_interleaved_tv_) { + validateMainLoop(main_loop_entry.first, main_loop_entry.second); + } +} + +// Returns true if the given tv is an "exit tv", +// see [Supported Interleaving Cases]. 
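+// An exit tv's value is only consumed outside of the interleaved region,
+// so reordering the interleave units cannot change the values that its
+// consumers observe.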
+bool InterleaveLoopInfo::isExitTv(
+    TensorView* tv,
+    const TensorViewVector& interleaved_tvs) {
+  // An output is always an exit.
+  if (tv->isFusionOutput()) {
+    return true;
+  }
+
+  for (auto use : fusion_->unordered_uses(tv)) {
+    // Check if any immediate consumer of tv is interleaved.
+    for (auto consumer_tv :
+         ir_utils::filterByType<TensorView>(use->outputs())) {
+      if (interleaved_tvs.has(consumer_tv)) {
+        return false;
+      }
+    }
+  }
+
+  // No immediate consumer of tv is interleaved, so tv is an exit tv.
+  return true;
+}
+
+void InterleaveLoopInfo::validateMainLoop(
+    IterDomain* concrete_main_loop,
+    const TensorViewVector& interleaved_tvs) {
+  // [Supported Interleaving Cases]
+  // Every expression inside the main loop or a subloop must fall into
+  // one of the following 3 cases:
+  // 1. It's double/circular buffered across a loop that is at the same
+  //  level as, or outside of, the main loop. E.g.
+  //   for i in ...    // loop 1
+  //    for j in ...   // loop 2 (interleave main loop)
+  //     for k in ...  // loop 3 (interleave sub loop)
+  //      tv0 [i%3*buffersize + ... ] = ...;
+  //  tv0 is circular buffered around loop1, so interleaving loop 3
+  //  with any other serial loops within loop2 will not make any consumer
+  //  of tv0 use the wrong value.
+  //  This by itself gives no guarantee about the producers of tv0;
+  //  safety there relies on the same check being run on them as well.
+  //
+  // 2. It's inlined into the subloop.
+  //  E.g.
+  //   for i in ...    // loop1 (interleave main loop)
+  //    for j in ...   // loop2 (interleave sub loop)
+  //     for k in ...  // loop3
+  //       tv0[k] = ...
+  //     for w in ...  // loop4
+  //       ... = t0[w]
+  //  The inlining semantically means that the consumer of tv0 above is
+  //  within loop2, so interleaving loop 2 with other loops within loop1
+  //  should not cause the consumer of tv0 to read wrong values, as the
+  //  relative order of producer and consumer is unchanged.
+  //
+  // 3. It's not a producer of any other interleaved tv's,
+  //  i.e. it is an "exit tv".
+  //   for i in ...    // loop1 (interleave main loop)
+  //    for j in ...   // loop2 (interleave subloop 1)
+  //      tv0[j] = ...
+  //    for k in ...   // loop3 (interleave subloop 2)
+  //      tv1[k] = ...
+  //
+  //   for m in ...
+  //     ... = tv0[m] + tv1[m];
+  //
+  //  In this case tv0 and tv1 are producing values that are used outside
+  //  of any of the expressions that are interleaved, so the interleaving
+  //  of loop2 and loop3 has no effect on the semantics.
+  for (auto tv : interleaved_tvs.vector()) {
+    if (isExitTv(tv, interleaved_tvs)) {
+      // Exit tv computation can be interleaved by Point 3 above.
+      continue;
+    }
+
+    // A double buffered tv doesn't need to be checked, see Point 1 above:
+    if (tv->isDoubleBuffered() || tv->isCircularBuffered()) {
+      auto db_axis =
+          GpuLower::current()->doubleBufferInfo().getDoubleBufferAxis(tv);
+
+      // Check that the double buffer axis is at, or to the left of,
+      // the main loop.
+      bool can_interleave = false;
+
+      // Iterating over the leaf domains from the left
+      for (auto id : tv->domain()->domain()) {
+        if (id == db_axis) {
+          // If we see the double buffer axis first then
+          // tv is double buffered on an outer loop,
+          // so it can be interleaved.
+          can_interleave = true;
+          break;
+        } else if (GpuLower::current()->caMap()->areMapped(
+                       id, concrete_main_loop, IdMappingMode::LOOP)) {
+          // If we see the main loop before seeing the double buffer axis,
+          // it cannot be proven safe to interleave by double buffering,
+          // but the other two points might apply.
+          can_interleave = false;
+        }
+      }
+
+      if (can_interleave) {
+        continue;
+      }
+    }
+
+    // If Points 1 and 3 didn't apply at this point,
+    // then Point 2 has to apply in order for this interleaving to be valid.
+    // TODO:
+    //  Maybe in follow ups more supported patterns could be added.
+
+    // Check that the subloop is on the left of the CA axis:
+    auto& concrete_subloops =
+        concrete_main_loop_to_subloop_map_.at(concrete_main_loop);
+    bool subloop_found = false;
+    for (auto id_it = tv->domain()->domain().begin();
+         id_it != tv->domain()->domain().begin() + tv->getComputeAtPosition();
+         id_it++) {
+      auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+          *id_it, IdMappingMode::LOOP);
+      if (concrete_subloops.has(concrete_id)) {
+        subloop_found = true;
+        break;
+      }
+    }
+    TORCH_INTERNAL_ASSERT(
+        subloop_found,
+        "unsupported interleaved tv ",
+        tv->toString(),
+        "; it needs to be either double buffered, an exit of the interleaved region, or inlined past the subloops");
+  }
+}
+
+namespace {
+
+// A data structure collecting the parameters used when realizing the
+// interleaving.
+struct InterLeaveConfig {
+  // Total number of units, a.k.a. the interleave factor,
+  // see [Loop Interleaving].
+  int64_t number_of_units = 1;
+
+  // Evaluated loop extent of each subloop.
+  std::unordered_map<IterDomain*, int64_t> concrete_id_to_extent_;
+};
+
+//! The loop interleaving pass that implements the interleaving
+//! transform, see [Loop Interleaving].
+class LoopInterLeaver : kir::ExprMutator {
+ public:
+  static std::vector<Expr*> run(std::vector<Expr*> exprs) {
+    // Interleave main loops one at a time.
+    for (auto& it : GpuLower::current()
+                        ->interleavedLoopInfo()
+                        .concreteMainLoopToSubloopMap()) {
+      LoopInterLeaver interleaver;
+      interleaver.concrete_main_loop_ = it.first;
+
+      interleaver.concrete_subloop_set_ = std::unordered_set<IterDomain*>(
+          it.second.vector().begin(), it.second.vector().end());
+      interleaver.traverseAndInsert(exprs);
+      exprs = interleaver.exprs_;
+    }
+    return exprs;
+  }
+
+ private:
+  using kir::ExprMutator::handle;
+
+  void handle(kir::ForLoop* fl) final {
+    auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        fl->iter_domain(), IdMappingMode::LOOP);
+
+    // For double buffered loops, only interleave the main stage.
+    if (concrete_main_loop_ == concrete_loop_id &&
+        fl->doubleBufferLoopStage() == DoubleBufferLoopStage::Main) {
+      handleMainLoop(fl);
+    } else {
+      kir::ExprMutator::handle(fl);
+    }
+  }
+
+  // Returns true if the expression is a subloop to be interleaved,
+  // see [Loop Interleaving].
+  bool isInterleavedSubloop(Expr* expr) {
+    if (auto loop = dynamic_cast<kir::ForLoop*>(expr)) {
+      auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
+          loop->iter_domain(), IdMappingMode::LOOP);
+      if (concrete_subloop_set_.count(concrete_loop_id) &&
+          // Do not interleave double buffer epilogs
+          loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Epilog &&
+
+          // Do not interleave any index computation expressions
+          !loop->loopTransformInfo().is_base_index_loop &&
+          !loop->loopTransformInfo().is_increment_loop) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Remove the original subloops once the interleaved
+  // versions have been inserted.
+  void clearSubLoops(
+      std::vector<kir::ForLoop*>& interleaved_subloops,
+      kir::ForLoop* main_loop) {
+    for (auto fl : interleaved_subloops) {
+      registerRemove(fl, &main_loop->body());
+    }
+    interleaved_subloops.clear();
+  }
+
+  // Realize the interleaving with the given for loop
+  // as the main loop, see [Loop Interleaving].
+ // + // [Loop Interleaving Impl] + // The implementation pass goes as the below example: + // + // for i in ... // main loop + // for j in ... // sub loop1 + // ... + // for k in ... // sub loop2 + // ... + // __syncthread(); + // for m in ... // sub loop3 + // ... + // for n in ... // sub loop4 + // ... + // + // This function loops through the body of the main loop + // and puts all the subloops encountered in `interleaved_subloops` + // vector. + // Whenever it sees an expression that is *not* a interleaved subloop, e.g. + // the syncthreads in the above example, the currently collected + // `interleaved_subloops`, i.e. loop1 and loop2 in this case, are + // emitted as interleaved units and the pass continues with an empty + // `interleaved_subloops` vector. + // As a result, in this example, sub loop1 and sub loop2 are interleaved + // while sub loop3 and sub loop4 are interleaved. + void handleMainLoop(kir::ForLoop* fl) { + // Collect the subloops encountered when looping + // over the main loop expressions. + std::vector interleaved_subloops; + + // Loop over the main loop body. + for (auto expr : fl->body().exprs()) { + if (auto loop = dynamic_cast(expr)) { + if ( + // Usually not useful to involve double buffer prologs + // and epilogs in the interleaving. + !isProlog(loop->doubleBufferLoopStage()) && + loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Epilog && + // Check if this expression is a subloop + isInterleavedSubloop(expr)) { + // Collect this sub loop to be realized later, see details above. + interleaved_subloops.push_back(expr->as()); + continue; + } + } + + // Main loop may have allocation expressions that can be safe + // to just continue collecting the subloop across as the interleave + // units will be realized after this expression, which means the + // allocation is still valid. + if (expr->isA()) { + continue; + } + + // This is the point where we see an expression that is *not* an + // interleaved subloop that we are collecting, so emit the currently + // collected interleaved subloops as interleaved units. + // And clear the collected vector before proceeding. + if (!interleaved_subloops.empty()) { + realizeInterleavedSubloops(expr, interleaved_subloops, true, fl); + clearSubLoops(interleaved_subloops, fl); + } + } + + // It's possible, actually common that all exprs within + // the main loop are subloops, so we will need to run + // another realization step after visiting the whole main + // loop. + if (!interleaved_subloops.empty()) { + realizeInterleavedSubloops( + fl->body().exprs().back(), interleaved_subloops, false, fl); + clearSubLoops(interleaved_subloops, fl); + } + } + + // Performs a deep loopnest clone if the expression + // is a loop nest. + // TODO: use common infra + Expr* cloneMaybeLoopNest(Expr* expr) { + auto fl = dynamic_cast(expr); + if (!fl) { + return expr; + } + + TORCH_INTERNAL_ASSERT(!expr->isA(), "unsupported"); + auto cloned_fl = IrBuilder::create(fl); + + for (auto loop_expr : fl->body().exprs()) { + cloned_fl->body().push_back(cloneMaybeLoopNest(loop_expr)); + } + + return cloned_fl; + } + + void handle(kir::IfThenElse*) final { + TORCH_INTERNAL_ASSERT( + false, "LoopInterleaving: no support yet post IfThenElse lowering"); + } + + // Emit the currently collected subloops as interleaved units, + // see [Loop Interleaving Impl]. + void realizeInterleavedSubloops( + // A insertion reference point + Expr* insert_point, + // Subloops to interleave. 
std::vector<kir::ForLoop*> sub_loops,
+      // Insert interleave units before the insertion point
+      // if true, after it if false.
+      bool insert_before,
+      // Main loop to interleave within.
+      kir::ForLoop* main_loop) {
+    // Container to collect the interleave units in interleaved order.
+    std::vector<kir::ForLoop*> interleave_units;
+
+    // Populate the parameters for interleaving these subloops.
+    auto config = getInterleaveConfig(main_loop, sub_loops);
+
+    // Repeat number_of_units times, each time creating
+    // an interleave unit for each subloop.
+    for (auto idx : c10::irange(config.number_of_units)) {
+      // Loop over each subloop
+      for (auto sub_loop : sub_loops) {
+        // Collect concrete id and extent
+        auto concrete_loop_id =
+            GpuLower::current()->caMap()->getConcreteMappedID(
+                sub_loop->iter_domain(), IdMappingMode::LOOP);
+
+        auto concrete_extent =
+            config.concrete_id_to_extent_.at(concrete_loop_id);
+
+        // Calculate the size of this unit.
+        auto interleave_unit = ceilDiv(concrete_extent, config.number_of_units);
+
+        // Set start and stop of this unit;
+        // stop needs to be the minimum of start+size and the original extent
+        // to avoid overrunning the original loop.
+        auto start_idx = idx * interleave_unit;
+        auto stop_idx = std::min(start_idx + interleave_unit, concrete_extent);
+
+        // No longer need to generate more of this subloop if
+        // start is already out of bounds.
+        if (start_idx < concrete_extent) {
+          auto start_val = SimplifyingIrBuilder::create<Int>(start_idx);
+          auto stop_val = SimplifyingIrBuilder::create<Int>(stop_idx);
+          interleave_units.push_back(
+              makeInterleavedUnit(sub_loop, start_val, stop_val));
+        }
+      }
+    }
+
+    if (insert_before) {
+      for (auto unit : interleave_units) {
+        registerInsertBefore(insert_point, unit, &main_loop->body());
+      }
+    } else {
+      // Need to insert in reverse order when inserting after, in order
+      // to maintain the original order defined in interleave_units.
+      for (auto it = interleave_units.rbegin(); it != interleave_units.rend();
+           it++) {
+        registerInsertAfter(insert_point, *it, &main_loop->body());
+      }
+    }
+  }
+
+  // Make an interleaved unit of the given subloop according to the given
+  // start and stop offset.
+  kir::ForLoop* makeInterleavedUnit(kir::ForLoop* fl, Val* start, Val* stop) {
+    // Create an outer loop with the same loop expressions but
+    // different start and stop.
+    auto outer_loop = IrBuilder::create<kir::ForLoop>(
+        fl->iter_domain(),
+        fl->index(),
+        start,
+        stop,
+        fl->step(),
+        fl->vectorize(),
+        fl->vectorize_shift(),
+        fl->isUnrolled(),
+        fl->loopTransformInfo().interLeaveUnit());
+
+    for (auto expr : fl->body().exprs()) {
+      outer_loop->body().push_back(cloneMaybeLoopNest(expr));
+    }
+
+    return outer_loop;
+  }
+
+  // Collect the info needed to realize interleaved loops,
+  // see [Loop Interleaving Impl].
+  InterLeaveConfig getInterleaveConfig(
+      kir::ForLoop* main_loop,
+      const std::vector<kir::ForLoop*> sub_loops_) {
+    TORCH_INTERNAL_ASSERT(
+        !sub_loops_.empty(), "Cannot generate config for empty subloops");
+    InterLeaveConfig interleave_config;
+    ExpressionEvaluator const_evaluator(sub_loops_[0]->iter_domain()->fusion());
+
+    for (auto fl : sub_loops_) {
+      auto maybe_value = const_evaluator.evaluate(fl->stop());
+      TORCH_INTERNAL_ASSERT(
+          maybe_value.has_value(), "non-constant interleaving not supported");
+      auto value = maybe_value.value().as<int64_t>();
+
+      auto concrete_loop_domain =
+          GpuLower::current()->caMap()->getConcreteMappedID(
+              fl->iter_domain(), IdMappingMode::LOOP);
+
+      // Collect the concrete extent of each of the subloops.
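+      // (Unit bounds are later computed from these constants as
+      // ceilDiv(extent, number_of_units), clamped to the extent.)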
+      interleave_config.concrete_id_to_extent_[concrete_loop_domain] = value;
+    }
+
+    // Read the interleave factor that the scheduler recorded for this
+    // main loop, see [Loop Interleaving].
+    interleave_config.number_of_units =
+        GpuLower::current()
+            ->interleavedLoopInfo()
+            .concreteMainLoopToFactorMap()
+            .at(GpuLower::current()->caMap()->getConcreteMappedID(
+                main_loop->iter_domain(), IdMappingMode::LOOP));
+
+    return interleave_config;
+  }
+
+ private:
+  // Marks the current main loop this pass
+  // is processing.
+  IterDomain* concrete_main_loop_ = nullptr;
+
+  // Set of subloop concrete IterDomains that will
+  // be interleaved within the main loop.
+  std::unordered_set<IterDomain*> concrete_subloop_set_;
+};
+} // namespace
+
+std::vector<Expr*> interLeaveDoubleBufferUnrolledLoops(
+    const std::vector<Expr*>& exprs) {
+  return LoopInterLeaver::run(exprs);
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.h b/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.h
new file mode 100644
index 000000000000..7150607f3d5a
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_interleaved_loop.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+//! Keeps track of loops that will be interleaved, see
+//! [Loop Interleaving].
+class InterleaveLoopInfo {
+  using ConcreteIdVector = VectorOfUniqueEntries<IterDomain*>;
+  using TensorViewVector = VectorOfUniqueEntries<TensorView*>;
+
+ public:
+  //! Collect info by traversing fusion expressions.
+  void build(Fusion* fusion);
+
+  //! Validate data consistency after interleaving,
+  //! see [Supported Interleaving Cases].
+  void validate();
+
+  //! See comment on concrete_main_loop_to_subloop_map_
+  const auto& concreteMainLoopToSubloopMap() const {
+    return concrete_main_loop_to_subloop_map_;
+  }
+
+  //! See comment on concrete_main_loop_to_number_of_units_
+  const auto& concreteMainLoopToFactorMap() const {
+    return concrete_main_loop_to_number_of_units_;
+  }
+
+ private:
+  //! Build phase 1: check all the tvs for
+  //! main loops where subloops are interleaved.
+  void collectInterleaveMainLoops();
+
+  //! Build phase 2: collect all tvs that are
+  //! computed within interleaved loops.
+  void collectInterleavedSubLoops();
+
+  //! Register a (main_loop, sub_loop) pair collected from
+  //! tv, see [Loop Interleaving].
+  void insertEntry(TensorView* tv, IterDomain* main_loop, IterDomain* sub_loop);
+
+  //! Returns true if the given id is loop mapped with
+  //! an interleaving main loop, see [Loop Interleaving].
+  bool isMainLoop(IterDomain* id);
+
+  //! Returns true if the id is loop mapped to a subloop
+  //! within a main loop mapped to concrete_main_id,
+  //! see also [Loop Interleaving].
+  bool isSubLoopOf(IterDomain* id, IterDomain* concrete_main_id);
+
+  //! Validate data consistency after interleaving,
+  //! see [Supported Interleaving Cases].
+  void validateMainLoop(
+      IterDomain* main_loop,
+      const TensorViewVector& interleaved_tvs);
+
+  //! Validation utility,
+  //! see [Supported Interleaving Cases].
+  bool isExitTv(TensorView* tv, const TensorViewVector& interleaved_tvs);
+
+ private:
+  //! Keeps track of interleaving main loops and the
+  //! interleaved subloops within, see [Loop Interleaving].
+  std::unordered_map<IterDomain*, ConcreteIdVector>
+      concrete_main_loop_to_subloop_map_;
+
+  //! Keeps track of the interleaving main loops and
+  //! all the tensors that are *produced* within
+  //! the interleaved subloops associated with
+  //!
each interleaving main loop. + std::unordered_map + concrete_main_loop_to_interleaved_tv_; + + //! Keeps track of the interleaving factor of each + //! interleaving main loop. see [Loop Interleaving]. + std::unordered_map concrete_main_loop_to_number_of_units_; + + //! Short-cut to the fusion this info keeps track of. + Fusion* fusion_ = nullptr; + + //! Cached used math vals from fusion_; + std::vector used_tvs_; +}; + +void validateInterleaving(Fusion* fusion); + +std::vector interLeaveDoubleBufferUnrolledLoops( + const std::vector& exprs); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index ee656e0b74e5..b0e53230cc5b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -925,7 +925,8 @@ void validateMmaTensors(MmaOp* mma) { // CA axis are constant sized to ensure early detection of // invalid mma schedules. ((id->isBroadcast() || id->extent()->isConstInt()) && - id->getParallelType() == ParallelType::Serial); + id->getParallelType() == ParallelType::Serial) || + id->isThread(); }), "All id's on the right of CA pos needs to be mma-swizzled by WarpMmaSwizzler\n", tv); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 484f755abe1e..48e9083cb80b 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -392,7 +392,10 @@ void scheduleMatmul( scheduler_utils::transformPropagateToAllFrom(cc, -1); // Schedule warp tile - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(cc, gemm_tile); + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kw Mwo Nwo Mw Nw (Mi Ni Ki)] + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( + cc, gemm_tile, true); // Propagate warp tile to main loop and epilog/output tvs scheduler_utils::BoundedDirectionalTransformPropagator::bothWays( @@ -442,8 +445,8 @@ void scheduleMatmul( b->computeAt(cc, 3); // Main Loop: - acr->computeAt(cc, -6); - bcr->computeAt(cc, -6); + acr->computeAt(cc, -8); + bcr->computeAt(cc, -8); // Add mma swizzle: // TODO: this section goes to a separate matmul util, @@ -491,12 +494,18 @@ void scheduleMatmul( acr->axis(-1)->parallelize(ParallelType::Vectorize); bcr->axis(-1)->parallelize(ParallelType::Vectorize); - // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kw Mwo Nwo Mw Nw (Mi Ni Ki)] cc->axis(0)->parallelize(ParallelType::BIDx); cc->axis(1)->parallelize(ParallelType::BIDy); - cc->axis(3)->parallelize(ParallelType::TIDz); - cc->axis(4)->parallelize(ParallelType::TIDy); + cc->axis(4)->parallelize(ParallelType::TIDz); + cc->axis(5)->parallelize(ParallelType::TIDy); + + scheduler_utils::parallelizeAllLike( + cc, + -1, + {acr, bcr, ab, bb, a, b}, + {ParallelType::TIDy, ParallelType::TIDz}); // Propagate mma output swizzle and parallelization down the DAG if (params.double_buffer_options.double_buffer_smem_write) { @@ -554,6 +563,13 @@ void scheduleMatmul( if (params.peel_main_loop) { cc->peelPredicatedLoop(2); } + + // Only interleave if using cp.async and + // all the shared memory is double buffered. 
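+  // (Interleaving reorders the unrolled cp.async copy subloops with the
+  // mma subloops; this is only known to be safe when the smem staging
+  // buffers are double buffered, see [Supported Interleaving Cases].)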
+ if (params.async_gmem_load_operands && + params.double_buffer_options.double_buffer_smem_write) { + cc->interleave(2, 2); + } } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index d985da926354..c56c35313fa7 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -1622,7 +1622,10 @@ BroadcastMultipleInformation getBroadcastMultiples( namespace matmul_utils { -void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { +void scheduleWarpTileWithReduction( + TensorView* tv, + MatMulTileOptions tile, + bool serial_k_loop_first) { // Assumes // [M, N, K] auto cta_tile = tile.cta_tile; @@ -1656,9 +1659,16 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { // -8 -7 -6 -5 -4 -3 -2 -1 // [Mwo Mw Mi Nwo Nw Ni Ko Ki] - tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + if (serial_k_loop_first) { + tv->reorder({{-8, -7}, {-7, -5}, {-6, -3}, {-5, -6}, {-3, -2}, {-2, -8}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [ Ko Mwo Nwo Mw Nw Mi Ni Ki] + } else { + tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + } + } else { // Split K over warp case: // Main difference is that an additional diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 373a879f740d..c41b14598c0c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -354,7 +354,14 @@ TORCH_CUDA_CU_API void scheduleContiguousVectorLoad( //! TODO: rewrite this one with makeTile TORCH_CUDA_CU_API void scheduleWarpTileWithReduction( TensorView* tv, - MatMulTileOptions tile); + MatMulTileOptions tile, + // Will put the inner unrolled k loop on the left + // most position to facilitate loop interleaving + // if serial_k_loop_first is true. + // TODO: should probably re-write all the test + // cases that nees this to be false and then + // remove this parameter. + bool serial_k_loop_first = false); //! Schedule utility for mma output in matmul main loop: //! 
Realize the hierarchical tiling based on the given tiling options.
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index e2ed698356ae..ab6582db551b 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -219,7 +219,9 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
       has_swizzle_op_(src->has_swizzle_op_),
       lift_read_address_(src->lift_read_address_),
       lift_write_address_(src->lift_write_address_),
-      skew_double_buffer_loop_(src->skew_double_buffer_loop_) {
+      skew_double_buffer_loop_(src->skew_double_buffer_loop_),
+      maybe_interleave_axis_and_factor_(
+          src->maybe_interleave_axis_and_factor_) {
   for (const auto id : src->axesToSwizzle()) {
     axes_to_swizzle_.push_back(ir_cloner->clone(id));
   }
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index f3783aa57527..6009c38cc480 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -26171,6 +26171,44 @@ TEST_F(NVFuserTest, FusionSimpleMemHoisting_CUDA) {
   testValidate(&fusion, outputs, {t0}, {t0}, __LINE__, __FILE__);
 }
 
+// Initial test case for interleaving unrolled loops
+TEST_F(NVFuserTest, FusionSimpleInterleaving1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [M]
+  auto tv0 = makeContigTensor(1);
+
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  auto tv3 = set(tv2);
+
+  fusion.addOutput(tv3);
+
+  tv3->split(0, 4);
+  tv3->split(0, 8);
+  tv3->reorder({{1, 2}, {2, 1}});
+  tv0->computeAt(tv3, 1);
+  tv2->computeAt(tv3, 2);
+
+  tv1->doubleBuffer();
+  tv2->doubleBuffer();
+
+  tv3->interleave(0);
+
+  at::manual_seed(0);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({256}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {t0});
+  auto cg_outputs = fe.runFusion({t0});
+
+  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
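Reviewer note (not part of the diff): the unit-splitting arithmetic that
realizeInterleavedSubloops applies to each subloop can be sanity-checked in
isolation. Below is a minimal standalone sketch of that arithmetic under the
rules stated in the [Loop Interleaving] note; `unitBounds` is a hypothetical
helper written for illustration and does not exist in this PR.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Split a loop of constant `extent` into at most `number_of_units` pieces of
// size ceilDiv(extent, number_of_units), clamping the last piece to the
// extent and dropping pieces that would start out of bounds.
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

std::vector<std::pair<int64_t, int64_t>> unitBounds(
    int64_t extent,
    int64_t number_of_units) {
  std::vector<std::pair<int64_t, int64_t>> bounds;
  const int64_t unit = ceilDiv(extent, number_of_units);
  for (int64_t idx = 0; idx < number_of_units; ++idx) {
    const int64_t start = idx * unit;
    if (start >= extent) {
      break; // no more non-empty units
    }
    bounds.emplace_back(start, std::min(start + unit, extent));
  }
  return bounds;
}

int main() {
  // Reproduces the factor-4 examples from the note:
  //   extent 8 -> [0,2) [2,4) [4,6) [6,8)
  //   extent 7 -> [0,2) [2,4) [4,6) [6,7)
  //   extent 6 -> [0,2) [2,4) [4,6)   (only three units)
  for (int64_t extent : {8, 7, 6}) {
    for (const auto& b : unitBounds(extent, 4)) {
      std::cout << "[" << b.first << "," << b.second << ") ";
    }
    std::cout << "\n";
  }
  return 0;
}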