1 change: 1 addition & 0 deletions build_variables.bzl
@@ -686,6 +686,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp",
"torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp",
"torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp",
"torch/csrc/jit/codegen/cuda/lower_interleaved_loop.cpp",
"torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp",
"torch/csrc/jit/codegen/cuda/lower_index.cpp",
"torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp",
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -2641,7 +2641,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
loop->loopTransformInfo().predicate_peel_stage ==
PredicatePeelStage::Main ||
loop->loopTransformInfo().double_buffer_loop_stage ==
DoubleBufferLoopStage::CircularInitProlog) {
DoubleBufferLoopStage::CircularInitProlog ||
loop->isInterleaveUnit()) {
code_ << " = " << gen_start << "; ";
} else {
// Do not start at the start of the ID when not parallelized. Instead,
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -529,6 +529,16 @@ class TORCH_CUDA_CU_API TensorView : public Val {
return skew_double_buffer_loop_;
}

void interleave(int axis, int number_of_units = 4) {
// TODO: move to tensorview.cpp
// and add upfront validation.
maybe_interleave_axis_and_factor_ = std::make_pair(axis, number_of_units);
}

auto getMaybeInterleavedAxisAndFactor() const {
return maybe_interleave_axis_and_factor_;
}

//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are either inputs/outputs of an
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
@@ -636,6 +646,9 @@ class TORCH_CUDA_CU_API TensorView : public Val {
//! Indicates if the prolog of the double buffer loop of double
//! buffer tensor will be lifted out of the main loop.
bool skew_double_buffer_loop_ = false;

//! Interleaving request set by interleave(): the first element is the
//! position of the loop within which the next level of unrolled loops are
//! interleaved, the second is the number of interleave units.
c10::optional<std::pair<int, int>> maybe_interleave_axis_and_factor_;
Collaborator review comment: Add comments on the pair.
};

//! A simple TensorView builder
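For reference, a minimal sketch of how a schedule might use the new interface; the helper function, tensor, and axis choices below are hypothetical and not taken from this PR:

// Hypothetical helper, illustrative only: request that the unrolled loops
// nested inside axis 0 of tv's loop nest be interleaved in groups of 4 units,
// then read the request back the way lowering would.
void scheduleInterleave(TensorView* tv) {
  tv->interleave(/*axis=*/0, /*number_of_units=*/4);
  if (auto axis_and_factor = tv->getMaybeInterleavedAxisAndFactor()) {
    int axis = axis_and_factor->first;    // position of the interleaved loop
    int factor = axis_and_factor->second; // number of interleave units
    (void)axis;
    (void)factor;
  }
}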
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -494,6 +494,11 @@ bool ForLoop::isTrivial() const {
return true;
}

if (isInterleaveUnit()) {
// The loop index matters inside an interleave unit, so the loop cannot be
// treated as trivial.
return false;
}

// Extent-1 loop: for (int i = 0; i < 1; ++i) {
if (start()->isZeroInt() && stop()->isOneInt() && step()->isOneInt()) {
return true;
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -568,6 +568,9 @@ struct LoopTransformInfo {
//! lifted memory address.
bool is_base_index_loop = false;

//! Tracks if this for loop is a unit from an interleaved set of loops.
bool is_interleave_unit = false;

//! Tracks if this for loop is for calculating inductive variable
//! increments.
bool is_increment_loop = false;
@@ -590,6 +593,12 @@ struct LoopTransformInfo {
return *this;
}

//! Setter API
LoopTransformInfo& interLeaveUnit() {
is_interleave_unit = true;
return *this;
}

// ! Setter API
LoopTransformInfo& incrementLoop() {
is_increment_loop = true;
@@ -693,6 +702,10 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
return loop_transform_info_.is_base_index_loop;
}

bool isInterleaveUnit() const {
return loop_transform_info_.is_interleave_unit;
}

private:
//! Returns if a loop could be unrolled.
bool isUnrollable() const;
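Since each setter returns *this, lowering code can tag a loop fluently. A minimal sketch, with construction of the surrounding ForLoop omitted:

// Illustrative only: build transform info for a loop that is one unit of an
// interleaved group of loops.
LoopTransformInfo info = LoopTransformInfo().interLeaveUnit();
// A ForLoop carrying this info reports isInterleaveUnit() == true, is not
// considered trivial by ForLoop::isTrivial(), and has its index initialized
// to the loop start (gen_start) by the codegen change above.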
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -15,6 +15,7 @@
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
#include <torch/csrc/jit/codegen/cuda/lower_instrument.h>
#include <torch/csrc/jit/codegen/cuda/lower_interleaved_loop.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
@@ -317,6 +318,8 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {

doubleBufferInfo().build(fusion_);

interleavedLoopInfo().build(fusion_);

compute_at_map_->allocateIndexVariables();

addressComputeInfo().build(fusion_);
@@ -359,12 +362,15 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
const auto predicate_peeled =
PredicatePeeling::peelPredicatedLoop(exprs_double_buffered);

const auto exprs_interleaved =
interLeaveDoubleBufferUnrolledLoops(predicate_peeled);

// This pass inserts predicates as well as branches in the code. Up until now
// the code is explicitly single shot for loop based. Need to be careful in
// later passes when doing any kind of insertions in loop nest structure as
// insertions could be on if then or else instead of directly on a for loop.
const auto exprs_unrolled_loops =
UnrollPass::runPass(fusion_, predicate_peeled);
UnrollPass::runPass(fusion_, exprs_interleaved);

const auto exprs_unrolled_mv_loops =
processMisalignedVectorization(exprs_unrolled_loops);
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower2device.h
@@ -10,6 +10,7 @@
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_fused_reduction.h>
#include <torch/csrc/jit/codegen/cuda/lower_index_hoist.h>
#include <torch/csrc/jit/codegen/cuda/lower_interleaved_loop.h>
#include <torch/csrc/jit/codegen/cuda/lower_mem_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate_elimination.h>
@@ -175,6 +176,14 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
return sync_map_;
}

auto& interleavedLoopInfo() {
return interleave_info_;
}

const auto& interleavedLoopInfo() const {
return interleave_info_;
}

kir::KernelPerformanceProfile& profile() {
return profile_;
}
@@ -236,6 +245,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
AddressComputeInfo address_compute_info_;
PredicatePeelingInfo predicate_peeling_info_;
kir::KernelPerformanceProfile profile_;
InterleaveLoopInfo interleave_info_;
std::unordered_set<Split*> divisible_splits_;

// Track which tensor views are inputs or outputs of a vectorized operation
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
@@ -650,6 +650,12 @@ class BufferUseDefInfo {
maybe_alloc_info.value()->can_use_inner_alias = false;
}

if (input_tv->isDoubleBuffered()) {
// Do not inner-alias double-buffered tensors: the whole allocated
// space (all stages) stays live throughout the loop.
maybe_alloc_info.value()->can_use_inner_alias = false;
}

auto outer_loop_info =
ascendLoopNestToSameLevelAs(maybe_alloc_info.value());

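A short illustration of why inner aliasing is disabled for double-buffered tensors; this is a hand-written, hypothetical CUDA sketch (single block, blockDim.x == kTile assumed), not code produced by this pass:

// Illustrative only: while one stage of the shared-memory buffer is consumed,
// the other is being filled for the next iteration, so the entire 2 * kTile
// allocation stays live across iterations and cannot be handed to another
// tensor inside the loop.
__global__ void doubleBufferedCopy(const float* in, float* out, int n_tiles) {
  constexpr int kTile = 128;
  __shared__ float buf[2][kTile];
  // Prolog: fill stage 0 with the first tile.
  buf[0][threadIdx.x] = in[threadIdx.x];
  __syncthreads();
  for (int i = 0; i < n_tiles; ++i) {
    // Prefetch the next tile into the stage that is not being read.
    if (i + 1 < n_tiles) {
      buf[(i + 1) % 2][threadIdx.x] = in[(i + 1) * kTile + threadIdx.x];
    }
    out[i * kTile + threadIdx.x] = buf[i % 2][threadIdx.x];
    __syncthreads();
  }
}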
19 changes: 17 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp
@@ -17,7 +17,7 @@ unsigned int getDoubleBufferAxisPosition(const TensorView* tv) {
// which defines the loop where prefetching is applied. Therefore,
// the CA position must be larger than 0.

TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0);
TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0, tv->toString());

// Unroll must not exist outside of double-buffer axis
auto first_unroll_it = std::find_if(
@@ -337,7 +337,10 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
}
}

if (stage_depth > 2) {
// Multi-stage circular buffering needs commits inserted in the prologs,
// but the corresponding waits are not needed until the main loop.
if (stage_depth > 2 && loop_type_ == DoubleBufferLoopStage::Prolog) {
Collaborator review comment: Is this a generic bug fix or is it related to the interleaving transformation?
cloned_top_level_loop_->body().push_back(
IrBuilder::create<kir::CpAsyncCommit>());
}
@@ -821,6 +824,10 @@ class DoubleBufferInserter : private kir::ExprMutator {
main_loop->iter_domain());
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);

// Make sure the commit is inserted right before the
// cp.async.wait in circular buffering.
bool need_insert_commit = stage_depth > 2;

// Check if a sync has been inserted by WAR sync pass.
auto block_sync_it = std::find_if(
main_loop->body().exprs().rbegin(),
@@ -832,10 +839,18 @@
// it can just be anywhere in the loop. Chose to
// place at the end arbitrarily.
main_loop->body().insert_after(end_of_loop_expr, cp_async_wait);
if (need_insert_commit) {
main_loop->body().insert_after(
end_of_loop_expr, IrBuilder::create<kir::CpAsyncCommit>());
}
} else {
// If a sync has been inserted, wait needs to be placed
// before the sync.
main_loop->body().insert_before(*block_sync_it, cp_async_wait);
if (need_insert_commit) {
main_loop->body().insert_before(
*block_sync_it, IrBuilder::create<kir::CpAsyncCommit>());
}
Collaborator review comment on lines +850 to +853: Not completely following what should be done here, but the above comment on need_insert_commit indicates a commit should be inserted before the wait, but this seems to insert a commit after the wait inserted above. Am I missing something?
}
}

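For context, a rough comment-style sketch of where the new kir::CpAsyncCommit nodes land for a circular buffer with stage_depth > 2, based on the comments added in this diff; this is pseudocode, not the literal generated kernel IR:

// prolog loop (one per prefetched stage):
//   ... cp.async loads for that stage ...
//   CpAsyncCommit                  // commit each prolog group, no wait yet
//
// main loop:
//   ... cp.async loads and compute ...
//   CpAsyncCommit                  // new: paired with the wait below
//   CpAsyncWait(stage_depth - 2)   // wait until at most stage_depth - 2
//                                  // groups remain in flight
//   block sync                     // only if the WAR pass inserted one; in
//                                  // that branch the commit and wait are
//                                  // placed before the sync, and their
//                                  // relative order there is what the review
//                                  // comment above questions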