diff --git a/third_party/nvfuser/CMakeLists.txt b/third_party/nvfuser/CMakeLists.txt
index 9e4ffaf1a066..5b7425ceeb10 100644
--- a/third_party/nvfuser/CMakeLists.txt
+++ b/third_party/nvfuser/CMakeLists.txt
@@ -46,6 +46,7 @@ list(APPEND NVFUSER_SRCS
   ${NVFUSER_SRCS_DIR}/index_compute.cpp
   ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp
   ${NVFUSER_SRCS_DIR}/instrumentation.cpp
+  ${NVFUSER_SRCS_DIR}/id_e_graph.cpp
   ${NVFUSER_SRCS_DIR}/ir_base_nodes.cpp
   ${NVFUSER_SRCS_DIR}/ir_builder.cpp
   ${NVFUSER_SRCS_DIR}/ir_cloner.cpp
@@ -351,6 +352,7 @@ if(BUILD_TEST)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_transpose.cpp)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_utils.cpp)
+  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_id_e_graph.cpp)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_indexing_ops.cpp)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_indexing.cpp)
   list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_gather_ops.cpp)
diff --git a/third_party/nvfuser/csrc/id_e_graph.cpp b/third_party/nvfuser/csrc/id_e_graph.cpp
new file mode 100644
index 000000000000..9b14684e3940
--- /dev/null
+++ b/third_party/nvfuser/csrc/id_e_graph.cpp
@@ -0,0 +1,298 @@
+#include "id_e_graph.h"
+
+// Standard headers reconstructed from usage below (std::cout,
+// std::unique_ptr, std::vector); the original include targets were lost.
+#include <iostream>
+#include <memory>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+void IterDomainEGraph::initGraph() {
+  // Initialize a partition of all IterDomains in the Fusion, initially
+  // containing all singleton classes.
+  for (auto v : fusion_.vals()) {
+    if (v->getValType() == ValType::IterDomain) {
+      auto id = reinterpret_cast<IterDomain*>(v);
+      all_ids_.push_back(id);
+      all_extents_.insert(id->extent());
+    }
+  }
+  id_partition_ = std::unique_ptr<UnionFind<IterDomain*>>(
+      new UnionFind<IterDomain*>(all_ids_));
+  extent_partition_ =
+      std::unique_ptr<UnionFind<Val*>>(new UnionFind<Val*>(all_extents_));
+
+  for (auto expr : fusion_.unordered_exprs()) {
+    // Expressions fall into one of the following categories:
+    //   Broadcast
+    //   Reduction
+    //   Reshape
+    //   Permute
+    //   Pointwise
+    if (expr->isA<BroadcastOp>()) {
+      std::cout << "Broadcast op: " << expr->toString() << std::endl;
+      auto bop = reinterpret_cast<BroadcastOp*>(expr);
+      auto bcast_flags = bop->getBroadcastDimFlags();
+      auto inp = reinterpret_cast<TensorView*>(bop->in());
+      auto outp = reinterpret_cast<TensorView*>(bop->out());
+      auto indom = inp->domain()->noReductions();
+      auto outdom = outp->domain()->noReductions();
+      for (size_t i = 0, j = 0; i < bcast_flags.size(); ++i) {
+        if (bcast_flags[i]) {
+          // This output dim is a new bcast dimension
+          continue;
+        }
+        auto idin = indom[j++];
+        auto idout = outdom[i];
+        id_partition_->mergeSetsFromValues(idin, idout);
+        extent_partition_->mergeSetsFromValues(idin->extent(), idout->extent());
+      }
+    } else if (expr->isA<ReductionOp>()) {
+      std::cout << "Reduction op: " << expr->toString() << std::endl;
+      auto rop = reinterpret_cast<ReductionOp*>(expr);
+      // Instead of flags, we just look at the output type to find reduction
+      // IDs, which will tell us which axes are being reduced over
+      auto inp = reinterpret_cast<TensorView*>(rop->in());
+      auto outp = reinterpret_cast<TensorView*>(rop->out());
+      auto indom = inp->domain()->domain();
+      auto outdom = outp->domain()->domain();
+      for (size_t i = 0; i < indom.size(); ++i) {
+        auto idin = indom[i];
+        auto idout = outdom[i];
+        if (idout->isReduction()) {
+          // Don't merge serial input with rdomain output; record a Reduces
+          // relation instead
+          id_relations_.push_back(Relation(RelationType::Reduces, idin, idout));
+          continue;
+        }
+        if (idin->isBroadcast()) {
+          id_relations_.push_back(
+              Relation(RelationType::ResolvesBroadcast, idin, idout));
+          continue;
+        }
+        id_partition_->mergeSetsFromValues(idin, idout);
+        extent_partition_->mergeSetsFromValues(idin->extent(), idout->extent());
+      }
+    } else if (expr->isA<TransposeOp>()) {
+      auto top = reinterpret_cast<TransposeOp*>(expr);
+      auto inp = reinterpret_cast<TensorView*>(top->in());
+      auto outp = reinterpret_cast<TensorView*>(top->out());
+      auto indom = inp->domain()->domain();
+      auto outdom = outp->domain()->domain();
+      auto n2o = top->new2old();
+      for (size_t i = 0; i < outdom.size(); ++i) {
+        auto oldi = n2o[i];
+        auto idin = indom[oldi];
+        auto idout = outdom[i];
+        id_partition_->mergeSetsFromValues(idin, idout);
+        extent_partition_->mergeSetsFromValues(idin->extent(), idout->extent());
+      }
+    } else if (expr->isA<ViewOp>()) {
+      std::cout << "Skipping reshape op: " << expr->toString() << std::endl;
+      //} else if (expr->isA()) { // pending Naoya's recent work
+      // TODO: Work through the list of other ops:
+      /*
+       *FullOp            # nothing to do
+       *ARangeOp          # nothing to do
+       *EyeOp             # nothing to do
+       *UnaryOp           # handled by else
+       *BinaryOp          # handled by else
+       *TernaryOp         # handled by else
+        SelectOp
+        IndexSelectOp
+        TorchGatherOp
+        RNGOp
+       *ReductionOp
+        GroupedReductionOp
+        WelfordOp
+        GroupedWelfordOp
+        LoadStoreOp
+        MmaOp
+       *BroadcastOp
+        SqueezeOp
+       *TransposeOp
+        ExpandOp
+        ShiftOp
+        GatherOp
+        ViewAsScalar
+        ViewOp
+        Split
+        Merge
+        Swizzle2D
+      */
+    } else {
+      // For pointwise Exprs (or ExpandOp), the handling is straightforward:
+      // - We simply merge matching (same position, ignoring reduction
+      //   domains) serial IterDomain classes with one another.
+      // - If the input IterDomain is a bcast and the output IterDomain is a
+      //   bcast, we merge their e-classes.
+      // - If the input is a bcast and the output is serial (resolution of
+      //   the broadcast), then we do not merge. Instead, in this case we add
+      //   a relation that the e-class of the output ID _resolves_ the
+      //   e-class of the input ID.
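+      //
+      // As an illustration (hypothetical tensors, not taken from any test):
+      // for T2[i2, i3] = T0[i0, b1] + T1[i4, i5], the loop below merges the
+      // e-classes of {i0, i2, i4} and of {i3, i5}, and records a
+      // "resolvesBcast" relation from the e-class of b1 to that of {i3, i5}.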
+      for (auto inp_val : expr->inputs()) {
+        if (inp_val->getValType() != ValType::TensorView) {
+          continue;
+        }
+        auto inp = (TensorView*)inp_val;
+        auto indom = inp->domain()->noReductions();
+        for (auto outp_val : expr->outputs()) {
+          if (outp_val->getValType() != ValType::TensorView) {
+            continue;
+          }
+          auto outp = (TensorView*)outp_val;
+          // Reduction domains aren't preserved through pointwise ops, so
+          // ignore them. Walk each non-reduction ID in inp and outp.
+          auto outdom = outp->domain()->noReductions();
+          TORCH_CHECK(
+              indom.size() == outdom.size(),
+              "Input and output noReductions domains must have equal length in pointwise op");
+          for (auto idin = indom.begin(), idout = outdom.begin();
+               idin != indom.end() && idout != outdom.end();
+               idin++, idout++) {
+            if ((*idin)->isBroadcast()) {
+              if (!(*idout)->isBroadcast()) {
+                // A serial output resolves the broadcast input: record the
+                // relation rather than merging the e-classes
+                id_relations_.push_back(
+                    Relation(RelationType::ResolvesBroadcast, *idin, *idout));
+                continue;
+              }
+              // bcast matched with bcast: fall through and merge e-classes
+            }
+            id_partition_->mergeSetsFromValues(*idin, *idout);
+            extent_partition_->mergeSetsFromValues(
+                (*idin)->extent(), (*idout)->extent());
+          }
+        }
+      }
+    }
+  }
+  std::cout << "Equivalence classes of IterDomains:" << std::endl;
+  for (auto s : id_partition_->getSets()) {
+    std::cout << "  c";
+    std::cout << id_partition_->findSetFromValue(s[0]) << ": ";
+    for (auto id : s) {
+      std::cout << id->toString() << ", ";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << "Equivalence classes of extents:" << std::endl;
+  for (auto s : extent_partition_->getSets()) {
+    std::cout << "  e";
+    std::cout << extent_partition_->findSetFromValue(s[0]) << ": ";
+    for (auto e : s) {
+      std::cout << e->toString() << ", ";
+    }
+    std::cout << std::endl;
+  }
+}
+
+//! Print out a diagram in GraphViz's .dot format
+void IterDomainEGraph::printDot(std::ostream& stream) {
+  stream << "digraph id_graph {" << std::endl;
+  // print inputs
+  stream << "  { rank = source;" << std::endl;
+  // Each input tensor is printed as a record-shaped box listing its
+  // IterDomain names (not the class labels), and edges are drawn from those
+  // records to their associated classes.
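+  // A hedged sketch of what this emits (hypothetical names and class
+  // numbers): a 2-D input T0 produces a record node plus one edge per ID,
+  //   T0 [shape=record, label="{T0|{<id0> iS0|<id1> iS1}}"];
+  //   T0:id0 -> c3;
+  //   T0:id1 -> c7;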
+  for (auto in_val : fusion_.inputs()) {
+    if (in_val->getValType() != ValType::TensorView) {
+      continue;
+    }
+    auto tv = reinterpret_cast<TensorView*>(in_val);
+    stream << "    T" << tv->name() << " [shape=record, label=\"{T"
+           << tv->name() << "|{";
+    bool first = true;
+    int i = 0;
+    for (auto id : tv->domain()->domain()) {
+      if (!first) {
+        stream << "|"; // separator
+      }
+      // Emit a record port <idN> so edges can attach to this IterDomain
+      stream << "<id" << i++ << "> " << id->getIterType()
+             << id->getParallelType() << id->name();
+      first = false;
+    }
+    stream << "}}\"];" << std::endl;
+  }
+  stream << "  }" << std::endl; // rank = source
+
+  // Place edges from inputs to classes
+  for (auto in_val : fusion_.inputs()) {
+    if (in_val->getValType() != ValType::TensorView) {
+      continue;
+    }
+    auto tv = reinterpret_cast<TensorView*>(in_val);
+    auto i = 0;
+    for (auto id : tv->domain()->domain()) {
+      auto c = id_partition_->findSetFromValue(id);
+      stream << "  T" << tv->name() << ":id" << i++ << " -> ";
+      stream << "c" << c << ";" << std::endl;
+    }
+  }
+
+  for (auto c : id_partition_->getSetIndices()) {
+    stream << "  c" << c << ";" << std::endl;
+  }
+
+  // Print all relations between classes
+  for (auto r : id_relations_) {
+    auto left_class = id_partition_->findSetFromValue(r.getLeft());
+    auto right_class = id_partition_->findSetFromValue(r.getRight());
+    // Place a labeled edge for every relation between ID classes
+    stream << "  c" << left_class << " -> c" << right_class;
+    stream << " [label=\"" << r.typeString() << "\"];" << std::endl;
+  }
+
+  // Place edges from classes to outputs
+  for (auto out_val : fusion_.outputs()) {
+    if (out_val->getValType() != ValType::TensorView) {
+      continue;
+    }
+    auto tv = reinterpret_cast<TensorView*>(out_val);
+    auto i = 0;
+    for (auto id : tv->domain()->domain()) {
+      auto c = id_partition_->findSetFromValue(id);
+      stream << "  c" << c << " -> T";
+      stream << tv->name() << ":id" << i++ << ";" << std::endl;
+    }
+  }
+
+  // print outputs
+  stream << "  { rank = max;" << std::endl;
+  // Each output tensor is printed as a record-shaped box with its IterDomain
+  // names (not the class labels) on top, and edges are drawn to those
+  // records from their associated classes.
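+  // Mirroring the input records above (hypothetical names and class
+  // numbers), a 2-D output T5 would be emitted with its ID ports on top:
+  //   T5 [shape=record, label="{{<id0> iS10|<id1> iS11}|T5}"];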
+  for (auto out_val : fusion_.outputs()) {
+    if (out_val->getValType() != ValType::TensorView) {
+      continue;
+    }
+    auto tv = reinterpret_cast<TensorView*>(out_val);
+    stream << "    T" << tv->name() << " [shape=record, label=\"{{";
+    bool first = true;
+    int i = 0;
+    for (auto id : tv->domain()->domain()) {
+      if (!first) {
+        stream << "|"; // separator
+      }
+      // Emit a record port <idN> so edges can attach to this IterDomain
+      stream << "<id" << i++ << "> " << id->getIterType()
+             << id->getParallelType() << id->name();
+      first = false;
+    }
+    stream << "}|T" << tv->name() << "}\"];" << std::endl;
+  }
+  stream << "  }" << std::endl; // rank = max
+
+  stream << "}" << std::endl; // digraph id_graph
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/third_party/nvfuser/csrc/id_e_graph.h b/third_party/nvfuser/csrc/id_e_graph.h
new file mode 100644
index 000000000000..8a241343295c
--- /dev/null
+++ b/third_party/nvfuser/csrc/id_e_graph.h
@@ -0,0 +1,79 @@
+#pragma once
+
+// Project and standard headers reconstructed from usage below; the original
+// include targets were lost.
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <union_find.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+enum class RelationType { ResolvesBroadcast, Reduces };
+
+inline std::string printRelationType(RelationType r) {
+  switch (r) {
+    case RelationType::ResolvesBroadcast:
+      return "resolvesBcast";
+    case RelationType::Reduces:
+      return "reduces";
+  }
+  return "";
+}
+
+class Relation {
+ public:
+  Relation(RelationType type, IterDomain* left, IterDomain* right)
+      : type_(type), left_(left), right_(right) {}
+
+  IterDomain* getLeft() const {
+    return left_;
+  }
+  IterDomain* getRight() const {
+    return right_;
+  }
+
+  RelationType type() const {
+    return type_;
+  }
+
+  std::string typeString() const {
+    return printRelationType(type());
+  }
+
+ protected:
+  RelationType type_;
+  IterDomain *left_, *right_;
+};
+
+//! Implements an E-graph whose "terms" are IterDomains.
+class TORCH_CUDA_CU_API IterDomainEGraph {
+ public:
+  explicit IterDomainEGraph(Fusion& fusion) : fusion_(fusion) {
+    initGraph();
+  }
+
+  void initGraph();
+
+  //! Print out a diagram as a hierarchical graph in GraphViz's .dot format
+  void printDot(std::ostream& stream = std::cout);
+
+ private:
+  Fusion& fusion_;
+  std::vector<size_t> e_class_ids_; // currently unused; element type assumed
+  std::vector<Relation> id_relations_;
+  std::vector<IterDomain*> all_ids_;
+  // Also track extents to find which are equivalent
+  std::unordered_set<Val*> all_extents_;
+  std::unique_ptr<UnionFind<IterDomain*>> id_partition_;
+  std::unique_ptr<UnionFind<Val*>> extent_partition_;
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/third_party/nvfuser/csrc/union_find.h b/third_party/nvfuser/csrc/union_find.h
new file mode 100644
index 000000000000..99b6e79ce11f
--- /dev/null
+++ b/third_party/nvfuser/csrc/union_find.h
@@ -0,0 +1,146 @@
+#pragma once
+
+// Headers reconstructed from usage below (TORCH_CHECK, std::unordered_map,
+// std::unordered_set, std::vector); the original include targets were lost.
+#include <c10/util/Exception.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+//! A tree-based union-find (aka disjoint-set) data structure using subtree
+//! sizes instead of ranks.
+//! cf. https://en.wikipedia.org/wiki/Disjoint-set_data_structure
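+//!
+//! A minimal usage sketch (hypothetical values; initGraph() uses this class
+//! the same way with IterDomain* elements):
+//!   UnionFind<int> uf(std::vector<int>{10, 20, 30});
+//!   uf.mergeSetsFromValues(10, 30);
+//!   // Now uf.findSetFromValue(10) == uf.findSetFromValue(30), and
+//!   // uf.getSets() yields two sets: {10, 30} and {20}.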
+template <typename T>
+class UnionFind {
+ public:
+  explicit UnionFind(size_t size) {
+    value_.resize(size);
+    parent_.resize(size);
+    size_.resize(size);
+    // Initialize with all singletons
+    for (size_t i = 0; i < size; ++i) {
+      parent_[i] = i;
+      size_[i] = 1;
+    }
+  }
+
+  explicit UnionFind(const std::vector<T>& vals) : UnionFind(vals.size()) {
+    for (size_t i = 0; i < vals.size(); ++i) {
+      this->setValue(i, vals[i]);
+    }
+  }
+
+  explicit UnionFind(const std::unordered_set<T>& vals)
+      : UnionFind(vals.size()) {
+    size_t i = 0;
+    for (auto& v : vals) {
+      this->setValue(i++, v);
+    }
+  }
+
+  void setValue(size_t pos, const T& val) {
+    value_[pos] = val;
+    val_to_pos_[val] = pos;
+  }
+  T getValue(size_t pos) const {
+    TORCH_CHECK(
+        pos < value_.size(),
+        "Passed invalid position ",
+        pos,
+        " for UnionFind with ",
+        value_.size(),
+        " entries");
+    return value_[pos];
+  }
+
+  //! Insert the given value and return the new number of elements
+  size_t insertValue(const T& val) {
+    auto pos = parent_.size();
+    parent_.push_back(pos);
+    size_.push_back(1);
+    val_to_pos_[val] = pos;
+    value_.push_back(val);
+    return pos + 1;
+  }
+
+  //! Find the integer position of val
+  size_t getPosition(const T& val) const {
+    return val_to_pos_.at(val);
+  }
+
+  //! Get the integer index of the set containing the given position
+  size_t findSet(size_t v) {
+    if (v == parent_[v]) {
+      return v;
+    }
+    // Path compression: this assignment re-points the node directly at the
+    // root index, so subsequent look-ups will not need to recurse.
+    return parent_[v] = findSet(parent_[v]);
+  }
+  //! Get the integer index of the set for a given value
+  size_t findSetFromValue(const T& val) {
+    return findSet(getPosition(val));
+  }
+
+  //! Get all elements in the set with the given root index (O(n))
+  std::vector<T> getSet(size_t idx) {
+    std::vector<T> s;
+    for (size_t i = 0; i < parent_.size(); ++i) {
+      if (findSet(i) == idx) {
+        s.push_back(value_.at(i));
+      }
+    }
+    return s;
+  }
+
+  //! Get a vector of all sets of values (up to O(n^2))
+  std::vector<std::vector<T>> getSets() {
+    std::vector<std::vector<T>> out;
+    for (size_t i = 0; i < parent_.size(); ++i) {
+      auto s = getSet(i);
+      if (s.size() > 0) {
+        out.push_back(s);
+      }
+    }
+    return out;
+  }
+
+  //! Get a vector of set indices
+  std::vector<size_t> getSetIndices() {
+    std::vector<size_t> ids;
+    for (size_t i = 0; i < parent_.size(); ++i) {
+      if (parent_[i] == i) {
+        ids.push_back(i);
+      }
+    }
+    return ids;
+  }
+
+  //! Merge two sets in the partition, given by index
+  void mergeSets(size_t a, size_t b) {
+    // Normalize to root indices in case non-root indices were passed
+    a = findSet(a);
+    b = findSet(b);
+    if (a != b) {
+      // Union by size: attach the smaller tree under the larger root
+      if (size_[a] < size_[b]) {
+        std::swap(a, b);
+      }
+      parent_[b] = a;
+      size_[a] += size_[b];
+    }
+  }
+  //! Merge the sets containing two given values
+  void mergeSetsFromValues(const T& val_a, const T& val_b) {
+    auto a = findSet(getPosition(val_a));
+    auto b = findSet(getPosition(val_b));
+    mergeSets(a, b);
+  }
+
+ private:
+  std::vector<T> value_;
+  std::unordered_map<T, size_t> val_to_pos_;
+  std::vector<size_t> parent_;
+  std::vector<size_t> size_;
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/third_party/nvfuser/test/test_gpu_id_e_graph.cpp b/third_party/nvfuser/test/test_gpu_id_e_graph.cpp
new file mode 100644
index 000000000000..a06bca470f12
--- /dev/null
+++ b/third_party/nvfuser/test/test_gpu_id_e_graph.cpp
@@ -0,0 +1,161 @@
+#if defined(USE_CUDA)
+// Test and project headers reconstructed from usage below; the original
+// include targets were lost.
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <executor.h>
+#include <fusion.h>
+#include <id_e_graph.h>
+#include <ir_all_nodes.h>
+#include <ir_builder.h>
+#include <ir_iostream.h>
+#include <ops/all_ops.h>
+#include <scheduler/all_schedulers.h>
+#include <test/test_utils.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/torch.h>
+
+#include <iostream>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace torch::jit::fuser::cuda;
+
+// Play with forming equivalence classes of IterDomains
+TEST_F(NVFuserTest, FusionIDEGraph) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [w]
+  // auto tv0 = makeSymbolicTensor(1);
+  auto tv0 = makeConcreteTensor({5});
+  fusion.addInput(tv0);
+
+  // [w, x]
+  auto tv1 = makeSymbolicTensor(2);
+  fusion.addInput(tv1);
+
+  // [w, y]
+  auto tv2 = makeSymbolicTensor(2);
+  fusion.addInput(tv2);
+
+  auto tv3 = set(tv0);
+  // [w]
+  auto tv4 = broadcast(tv3, {false, true});
+  // [w, 1]
+  auto tv5 = add(tv4, tv1);
+  // [w, x]
+  fusion.addOutput(tv5);
+
+  // [w]
+  auto tv6 = broadcast(tv3, {false, true});
+  // [w, 1]
+  auto tv7 = add(tv6, tv2);
+  // [w, y]
+  fusion.addOutput(tv7);
+
+  auto tv8 = sum(tv1, {0});
+  auto tv9 = add(tv8, tv0);
+  fusion.addOutput(tv9);
+
+  fusion.printMath();
+  // fusion.print();
+
+  IterDomainEGraph eg(fusion);
+
+  eg.printDot();
+}
+
+// Very simple graph with a broadcast and no reductions
+TEST_F(NVFuserTest, FusionSimpleMulIDGraph) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 3});
+  fusion.addInput(tv0);
+  auto tv1 = makeConcreteTensor({2});
+  fusion.addInput(tv1);
+
+  auto tv2 = broadcast(tv1, {false, true});
+  auto tv3 = mul(tv0, tv2);
+
+  fusion.addOutput(tv3);
+
+  fusion.printMath();
+  // fusion.print();
+
+  IterDomainEGraph eg(fusion);
+
+  eg.printDot();
+}
+
+// Simple reshape example
+TEST_F(NVFuserTest, FusionReshapeIDGraph) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 3, 5, 12});
+  fusion.addInput(tv0);
+
+  auto tv1 = view(tv0, {2, 3, 5, 12}, {6, 5, 4, 3});
+
+  fusion.addOutput(tv1);
+
+  fusion.printMath();
+  // fusion.print();
+
+  IterDomainEGraph eg(fusion);
+
+  eg.printDot();
+}
+
+// Gram matrix (inner product matrix) example
+TEST_F(NVFuserTest, FusionGramMatrixIdGraph) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [n, d]
+  auto tv0 = makeConcreteTensor({5, 7});
+  fusion.addInput(tv0);
+
+  // [1, n, d]
+  auto tv1 = broadcast(tv0, {true, false, false});
+  // [n, 1, d]
+  auto tv2 = broadcast(tv0, {false, true, false});
+
+  // [n, n, d]
+  auto tv3 = mul(tv1, tv2);
+
+  // [n, n]
+  auto tv4 = sum(tv3, {2});
+
+  fusion.addOutput(tv4);
+
+  fusion.printMath();
+  // fusion.print();
+
+  IterDomainEGraph eg(fusion);
+
+  eg.printDot();
+
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor aten_input = at::randn({5, 7}, options);
+
+  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+  scheduleReduction(&fusion, *reduction_params);
+
+  auto lparams = reduction_params->lparams;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {aten_input}, lparams);
+  // No broadcasting needed, so the last optional argument is omitted
+  auto cg_outputs = fe.runFusion({aten_input}, lparams);
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)