Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/nvfuser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ if(BUILD_TEST)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu1.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu2.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu3.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_cat.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_compute_with.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_expr_simplifier.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_external_src.cpp)
Expand Down
25 changes: 25 additions & 0 deletions third_party/nvfuser/csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2761,6 +2761,31 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const CatOp* cat) final {
auto out = gen(cat->output(0));
auto cat_idx = gen(cat->getConcatenatedDomainIndex());

for (const auto i : c10::irange(cat->inputs().size())) {
auto inp = cat->input(i)->as<kir::TensorIndex>();
auto inp_str = gen(inp);
if (i < cat->inputs().size() - 1) {
if (i == 0) {
indent() << "if (";
} else {
indent() << "} else if (";
}
code_ << gen(cat->getPred(i)) << ") {\n";
} else {
// last case doesn't need to be predicated
indent() << "} else {\n";
}

indent() << kTab << out << " = " << gen(inp) << ";\n";
}

indent() << "}\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
Expand Down
16 changes: 13 additions & 3 deletions third_party/nvfuser/csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ bool IterDomainGraph::exprsMap(
}

TORCH_INTERNAL_ASSERT(
first->isA<Merge>() || first->isA<Split>(),
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->isA<Merge>() || first->isA<Split>() || first->isA<Resize>(),
"Merge, split and Expand are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->toString());

auto first_ids = ir_utils::filterByType<IterDomain>(
Expand Down Expand Up @@ -169,6 +169,15 @@ bool IterDomainGraph::exprsMap(
}
}

if (first->isA<Resize>()) {
auto first_expand = first->as<Resize>();
auto second_expand = second->as<Resize>();
if (!first_expand->left()->sameAs(second_expand->left()) ||
!first_expand->right()->sameAs(second_expand->right())) {
return false;
}
}

return true;
}

Expand Down Expand Up @@ -562,8 +571,9 @@ void IterDomainGraph::build(Fusion* fusion) {
consumer_tv->getMaybeRFactorDomain().end()});
for (auto expr : exprs) {
auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
// TODO: Check side effects
TORCH_INTERNAL_ASSERT(
expr->isA<Split>() || expr->isA<Merge>(),
expr->isA<Split>() || expr->isA<Merge>() || expr->isA<Resize>(),
"Wasn't expecting the expression type of:\n",
expr->toString(),
"\nto be an expression defined in an rfactor transformation.");
Expand Down
59 changes: 59 additions & 0 deletions third_party/nvfuser/csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) {
}
}

// Replays a Resize through the ordering/contiguity bookkeeping: the
// resized output ID takes over the input ID's slot in active_ids_ and
// inherits the input's root-ID set and ordering properties.
void OrderedIdInformation::handle(Resize* expand) {
  // Find the input in the active_ids_ vector; if it is not active,
  // this resize is irrelevant to the IDs being tracked.
  const auto in_it =
      std::find(active_ids_.begin(), active_ids_.end(), expand->in());

  if (in_it == active_ids_.end()) {
    return;
  }

  auto in_pos = std::distance(active_ids_.begin(), in_it);

  // Check whether the input is marked consistently ordered.
  const auto in_ordered_it = consistently_ordered_ids_.find(expand->in());

  bool in_ordered = in_ordered_it != consistently_ordered_ids_.end();

  // Get the root ids the input was derived from.
  const auto in_root_ids_it = id_to_root_ids_.find(expand->in());

  TORCH_INTERNAL_ASSERT(
      in_root_ids_it != id_to_root_ids_.end(),
      "Error replaying transforms in contiguous ID checker.");

  const auto& in_root_ids = in_root_ids_it->second;

  // Update map for outputs
  // Replace the input with the output ID in place (same position).
  active_ids_[in_pos] = expand->out();

  // Not completely certain, but propagating these properties should be
  // fine, since a resize is a one-to-one input-to-output transform.
  if (in_ordered) {
    consistently_ordered_ids_.emplace(expand->out());
  }

  if (exclusivelyConsumesRoots(expand->in())) {
    exclusively_consumes_roots_.emplace(expand->out());
  }

  // NOTE(review): relies on unordered_map reference stability across
  // insertion — valid per the standard, as rehashing does not invalidate
  // references to elements.
  id_to_root_ids_[expand->out()] = in_root_ids;
}

NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
// TODO: Revisit reduction rfactor axes and propagation. Should probably use
// ca_map to propogate non divisibility dependencies across exact map. Still
Expand Down Expand Up @@ -488,6 +530,19 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
{root_domain_.begin(), root_domain_.end()},
{ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto expand = dynamic_cast<Resize*>(expr)) {
expand_deps_.insert(expand->out());
} else {
if (std::any_of(
expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) {
return inp->isA<IterDomain>() &&
expand_deps_.count(inp->as<IterDomain>());
})) {
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
expand_deps_.insert(out);
}
}
}
handle(expr);
}
}
Expand Down Expand Up @@ -560,6 +615,10 @@ void ContigIDs::handle(Merge* merge) {
return;
}

if (expand_deps_.count(merge->out())) {
return;
}

// Now we know merge->out is a contiguously indexable ID

TORCH_INTERNAL_ASSERT(
Expand Down
8 changes: 8 additions & 0 deletions third_party/nvfuser/csrc/contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class OrderedIdInformation : public OptInDispatch {

void handle(Swizzle2D* swizzle) override;

void handle(Resize* expand) override;

// Track which root ids were used to generate each iter domain
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
id_to_root_ids_;
Expand Down Expand Up @@ -248,6 +250,10 @@ class ContigIDs : public OptInDispatch {
// cases, depending on specific swizzle type and axes.
void handle(Swizzle2D* swizzle) override {}

  // Disable contig indexing: an index into a resized ID needs to be
  // mapped back to its input ID, so the resize output cannot join a
  // contiguously indexable group.
  void handle(Resize* expand) override {}

IterDomain* getCAIndexConcreteId(IterDomain* id) const;

//! True if an ID is indexable.
Expand Down Expand Up @@ -300,6 +306,8 @@ class ContigIDs : public OptInDispatch {
std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;

NonDivisibleSplitDependencies non_divisible_id_info_;

std::unordered_set<IterDomain*> expand_deps_;
};

} // namespace nvfuser
56 changes: 56 additions & 0 deletions third_party/nvfuser/csrc/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ void Expr::dispatch(T handler, Expr* expr) {
ptr(handler)->handle(expr->as<SqueezeOp>());
return;
}
if (expr->isStrictlyA<CatOp>()) {
ptr(handler)->handle(expr->as<CatOp>());
return;
}
if (expr->isStrictlyA<PadOp>()) {
ptr(handler)->handle(expr->as<PadOp>());
return;
}
if (expr->isStrictlyA<SliceOp>()) {
ptr(handler)->handle(expr->as<SliceOp>());
return;
}
if (expr->isStrictlyA<Split>()) {
ptr(handler)->handle(expr->as<Split>());
return;
Expand All @@ -185,6 +197,10 @@ void Expr::dispatch(T handler, Expr* expr) {
ptr(handler)->handle(expr->as<Swizzle2D>());
return;
}
if (expr->isStrictlyA<Resize>()) {
ptr(handler)->handle(expr->as<Resize>());
return;
}
if (expr->isStrictlyA<TransposeOp>()) {
ptr(handler)->handle(expr->as<TransposeOp>());
return;
Expand Down Expand Up @@ -430,6 +446,18 @@ void Expr::constDispatch(T handler, const Expr* expr) {
ptr(handler)->handle(expr->as<SqueezeOp>());
return;
}
if (expr->isStrictlyA<CatOp>()) {
ptr(handler)->handle(expr->as<CatOp>());
return;
}
if (expr->isStrictlyA<PadOp>()) {
ptr(handler)->handle(expr->as<PadOp>());
return;
}
if (expr->isStrictlyA<SliceOp>()) {
ptr(handler)->handle(expr->as<SliceOp>());
return;
}
if (expr->isStrictlyA<Split>()) {
ptr(handler)->handle(expr->as<Split>());
return;
Expand All @@ -442,6 +470,10 @@ void Expr::constDispatch(T handler, const Expr* expr) {
ptr(handler)->handle(expr->as<Swizzle2D>());
return;
}
if (expr->isStrictlyA<Resize>()) {
ptr(handler)->handle(expr->as<Resize>());
return;
}
if (expr->isStrictlyA<TransposeOp>()) {
ptr(handler)->handle(expr->as<TransposeOp>());
return;
Expand Down Expand Up @@ -811,6 +843,15 @@ void OptOutConstDispatch::handle(const BroadcastOp* stmt) {
void OptOutConstDispatch::handle(const SqueezeOp* stmt) {
unhandled(stmt);
}
// Default const-dispatch handlers for the tensor resize ops
// (cat/pad/slice): fall through to unhandled() so subclasses only
// override the node types they care about.
void OptOutConstDispatch::handle(const CatOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const PadOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const SliceOp* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const Split* stmt) {
unhandled(stmt);
Expand All @@ -821,6 +862,9 @@ void OptOutConstDispatch::handle(const Merge* stmt) {
// Default const-dispatch handlers: defer to unhandled() unless a
// subclass overrides the specific node type.
void OptOutConstDispatch::handle(const Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Resize* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TransposeOp* stmt) {
  unhandled(stmt);
}
Expand Down Expand Up @@ -985,6 +1029,15 @@ void OptOutDispatch::handle(BroadcastOp* stmt) {
void OptOutDispatch::handle(SqueezeOp* stmt) {
unhandled(stmt);
}
// Default mutable-dispatch handlers for the tensor resize ops
// (cat/pad/slice): fall through to unhandled() so subclasses only
// override the node types they care about.
void OptOutDispatch::handle(CatOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(PadOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(SliceOp* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(Split* stmt) {
unhandled(stmt);
Expand All @@ -995,6 +1048,9 @@ void OptOutDispatch::handle(Merge* stmt) {
// Default mutable-dispatch handlers: defer to unhandled() unless a
// subclass overrides the specific node type.
void OptOutDispatch::handle(Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Resize* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TransposeOp* stmt) {
  unhandled(stmt);
}
Expand Down
12 changes: 12 additions & 0 deletions third_party/nvfuser/csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ class ShiftOp;
class GatherOp;
class ViewAsScalar;
class ViewOp;
class CatOp;
class PadOp;
class SliceOp;

// Exprs
class Split;
class Merge;
class Swizzle2D;
class Resize;

namespace kir {
class Predicate;
Expand Down Expand Up @@ -168,10 +172,14 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const MmaOp* stmt);
virtual void handle(const BroadcastOp* stmt);
virtual void handle(const SqueezeOp* stmt);
virtual void handle(const CatOp* stmt);
virtual void handle(const PadOp* stmt);
virtual void handle(const SliceOp* stmt);

virtual void handle(const Split* stmt);
virtual void handle(const Merge* stmt);
virtual void handle(const Swizzle2D* stmt);
virtual void handle(const Resize* stmt);
virtual void handle(const TransposeOp* stmt);
virtual void handle(const ExpandOp* stmt);
virtual void handle(const ShiftOp* stmt);
Expand Down Expand Up @@ -241,10 +249,14 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(MmaOp* stmt);
virtual void handle(BroadcastOp* stmt);
virtual void handle(SqueezeOp* stmt);
virtual void handle(CatOp* stmt);
virtual void handle(PadOp* stmt);
virtual void handle(SliceOp* stmt);

virtual void handle(Split* stmt);
virtual void handle(Merge* stmt);
virtual void handle(Swizzle2D* stmt);
virtual void handle(Resize* stmt);
virtual void handle(TransposeOp* stmt);
virtual void handle(ExpandOp* stmt);
virtual void handle(ShiftOp* stmt);
Expand Down
Loading