163 changes: 84 additions & 79 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp

Large diffs are not rendered by default.

90 changes: 50 additions & 40 deletions torch/csrc/jit/codegen/cuda/compute_at_map.h
@@ -63,37 +63,20 @@ class TORCH_CUDA_CU_API IterDomainGraph {
public:
IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);

const DisjointSets<IterDomain*>& permissiveNodes() const {
return permissive_nodes_;
}
const DisjointSets<IterDomain*>& exactNodes() const {
return exact_nodes_;
}
const DisjointSets<IterDomain*>& almostExactNodes() const {
return almost_exact_nodes_;
}
const DisjointSets<IterDomain*>& loopNodes() const {
return loop_nodes_;
}

// Returns the disjoint set according to one of the mapping mode types.
const DisjointSets<IterDomain*>& getNodes(IdMappingMode mode) const;
// Consumers and producers are not symmetric like the other sets
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
consumers() const {
return consumers_;
}

const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
producers() const {
return producers_;
}

const DisjointSets<IterDomain*>& siblings() const {
return sibling_sets_;
}

const VectorOfUniqueEntries<IterDomain*>& allIds() const {
return all_ids_;
}

// TODO: Seems a bit unfortunate that this isn't IterDomain local information.
const std::unordered_set<IterDomain*>& viewRfactorIds() const {
return view_rfactor_ids_;
}
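The four per-mode accessors above collapse into the single getNodes(IdMappingMode) lookup. As a rough sketch of what that dispatch might look like over the new nodes_ map (the actual definition lives in compute_at_map.cpp, whose diff is collapsed above, so this is an assumption, not the PR's code):

// Sketch only -- assumes nodes_ is the unordered_map declared later in this
// class; the real implementation in compute_at_map.cpp may differ.
const DisjointSets<IterDomain*>& IterDomainGraph::getNodes(
    IdMappingMode mode) const {
  auto it = nodes_.find(mode);
  TORCH_INTERNAL_ASSERT(
      it != nodes_.end(), "No disjoint sets registered for mapping mode.");
  return it->second;
}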
@@ -102,12 +85,11 @@ class TORCH_CUDA_CU_API IterDomainGraph {
// id_map have matching inputs (if forward), or outputs (if not forward).
// Returning true means the expressions are "the same", in that they modify
// matching original extents by the same amount.
static bool exprsMap(
Expr* first,
Expr* second,
bool forward,
const DisjointSets<IterDomain*>& id_map);
bool exprsMap(Expr* first, Expr* second, bool forward, IdMappingMode mode)
const;

// Returns whether a self mapping was detected that would invalidate
// assumptions of the overall lowering system.
bool hasSelfMapping() const {
return self_mapping_info_.has_value();
}
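For intuition, a hedged usage sketch of the new member exprsMap (test-helper style; none of these tensors appear in the PR): two splits by the same factor on exactly-mapped IDs should be considered "the same" under IdMappingMode::EXACT.

// Illustrative sketch, not PR code. Assumes the usual nvfuser test helpers
// (makeSymbolicTensor, set) are available.
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);
// Apply the same split to both tensors.
tv1->split(0, 4);
tv2->split(0, 4);
IterDomainGraph id_graph(&fusion);
// The two Split exprs consume exactly-mapped root IDs and use the same
// factor, so walking forward they should map as "the same" expression.
bool same = id_graph.exprsMap(
    tv1->axis(0)->definition(),
    tv2->axis(0)->definition(),
    /*forward=*/true,
    IdMappingMode::EXACT);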
@@ -118,29 +100,51 @@ class TORCH_CUDA_CU_API IterDomainGraph {
private:
void build(Fusion* fusion);

// Non-const, internal-only version of getNodes.
DisjointSets<IterDomain*>& nodes(IdMappingMode mode);

// Simple alias
void mapNodes(IterDomain* id0, IterDomain* id1, IdMappingMode mode) {
nodes(mode).mapEntries(id0, id1);
}

void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);

// Checks exprsMap; if it passes, maps outputs (if forward) or inputs
// (if not forward) in the exact and permissive maps.
void mapThroughExpr(Expr* first, Expr* second, bool forward);
// Checks if exprs are considered "the same", where sameness means inputs and
// outputs in the same position across the expressions map in the provided
// MappingMode. If the expressions are determined to be the same, then:
//   if forward, map outputs;
//   else, map inputs;
// in the provided mode.
void mapThroughExpr(
Expr* first,
Expr* second,
bool forward,
IdMappingMode mode);

DisjointSets<IterDomain*> permissive_nodes_;
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> almost_exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;
// Keeps a disjoint set entry for all IterDomain mapping mode types.
//
// Using an array here might be nice, but it seems hard to use an enum as an
// array key
// https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum
std::unordered_map<IdMappingMode, DisjointSets<IterDomain*>> nodes_;

// Consumers and producers are not symmetric like the other sets
// TODO: Generalize to mapping type. Mappings between producer TV ids and
// consumer TV ids depend on the mapping type.
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
producers_;

DisjointSets<IterDomain*> sibling_sets_;

VectorOfUniqueEntries<IterDomain*> all_ids_;

// Holds a set of iter domains that are considered view rfactor ids. This
// identification is particularly important for understanding whether split
// operations are divisible.
std::unordered_set<IterDomain*> view_rfactor_ids_;

// Debug information recording any self mapping found in a TensorView.
c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
self_mapping_info_ = c10::nullopt;
};
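On the array-vs-map question in the comment above nodes_: since C++14, std::hash is specialized for enumeration types, so an unordered_map keyed directly by the mode enum works without a custom hasher; only a fixed-size std::array would need the enumerator count the linked question is about. A standalone illustration (the enum here is a stand-in, not the real IdMappingMode):

#include <unordered_map>

// Stand-in enum for illustration; the real IdMappingMode lives in the IR.
enum class MappingMode { EXACT, ALMOSTEXACT, PERMISSIVE, LOOP };

int main() {
  std::unordered_map<MappingMode, int> set_counts;  // hashes out of the box
  set_counts[MappingMode::EXACT] = 3;
  return set_counts.count(MappingMode::LOOP) ? 1 : 0;  // 0: never inserted
}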
@@ -159,6 +163,8 @@ class TORCH_CUDA_CU_API ComputeAtMap {
//! Run through disjoint sets in the LOOP map, make sure there's only one
//! non-serial parallel type in each disjoint set, set the parallel type of
//! all IterDomains in the disjoint set to that PType.
//!
//! TODO: Should this be moved to parallel validation?
void validateAndPropagatePType();
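In pseudocode, the pass described above might look like the following sketch (the actual implementation is in compute_at_map.cpp; loop_map stands for the LOOP-mode disjoint sets):

// Sketch: enforce at most one non-serial ParallelType per LOOP disjoint set,
// then propagate it to every member.
for (const auto& set_ptr : loop_map.disjointSets()) {
  ParallelType common = ParallelType::Serial;
  for (IterDomain* id : set_ptr->vector()) {
    if (id->getParallelType() == ParallelType::Serial) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        common == ParallelType::Serial || common == id->getParallelType(),
        "Conflicting parallel types in one loop disjoint set.");
    common = id->getParallelType();
  }
  for (IterDomain* id : set_ptr->vector()) {
    id->parallelize(common);
  }
}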

//! Run through disjoint sets in the LOOP map and allocate the index
@@ -178,11 +184,15 @@
//! Under this condition, we can pre-allocate all required index
//! variable integers before creating any kir::forloop, and this
//! would help optimizing the generated integer math for indexing.
//!
//! TODO: Should this be moved to an indexing map structure outside of
//! ComputeAtMap that has a ComputeAtMap reference?
void allocateIndexVariables();

//! Returns whether id0 and id1 are mapped to each other in the provided IdMappingMode
bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const;

//! Simple alias to IdGraph mappings.
bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const {
return idGraph().getNodes(mode).strictAreMapped(id0, id1);
}
//! Returns an iter domain that is the maximum expanded size of all iter
//! domains the one provided maps to. Useful for opening loops to the correct
//! iteration size. Not guaranteed to return the same ID every call, but is
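Since areMapped is now just an inline forward, the following two queries are interchangeable (ca_map, id0, and id1 are hypothetical placeholders, not names from the PR):

// Equivalent queries after this change.
bool a = ca_map.areMapped(id0, id1, IdMappingMode::ALMOSTEXACT);
bool b = ca_map.idGraph()
             .getNodes(IdMappingMode::ALMOSTEXACT)
             .strictAreMapped(id0, id1);
TORCH_INTERNAL_ASSERT(a == b);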
10 changes: 6 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
@@ -90,8 +90,10 @@ std::unordered_set<Split*> getAllDivisibleSplits(
auto concrete_id = entry.first;
auto original_view_split = entry.second;

const auto& exact_mapped_ids =
ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector();
const auto& exact_mapped_ids = ca_map->idGraph()
.getNodes(IdMappingMode::EXACT)
.getDisjointSetOf(concrete_id)
.vector();
for (auto other_id : exact_mapped_ids) {
if (other_id->definition() == nullptr) {
continue;
@@ -102,11 +104,11 @@
continue;
}

if (IterDomainGraph::exprsMap(
if (ca_map->idGraph().exprsMap(
original_view_split,
other_id->definition(),
false,
ca_map->idGraph().exactNodes())) {
IdMappingMode::EXACT)) {
all_divisible_splits.emplace(other_id->definition()->as<Split>());
}
}
7 changes: 5 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -1271,8 +1271,11 @@ namespace {
bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) {
return std::any_of(ids.begin(), ids.end(), [&](Val* val) {
return val->isA<IterDomain>() &&
GpuLower::current()->caMap()->areMapped(
id, val->as<IterDomain>(), IdMappingMode::PERMISSIVE);
GpuLower::current()
->caMap()
->idGraph()
.getNodes(IdMappingMode::PERMISSIVE)
.permissiveAreMapped(id, val->as<IterDomain>());
});
}
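The PERMISSIVE mode used here is the one that maps across broadcasts, which EXACT deliberately does not. A hedged sketch of the distinction (tensors are illustrative only, not from the PR):

// Sketch, not PR code: broadcast axes map permissively but not exactly.
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(1);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = broadcast(tv0, {false, true});
auto tv3 = add(tv2, tv1);
fusion.addOutput(tv3);
IterDomainGraph id_graph(&fusion);
// tv2's broadcast axis is concretized by tv3's second axis:
bool permissive = id_graph.getNodes(IdMappingMode::PERMISSIVE)
                      .permissiveAreMapped(tv2->axis(1), tv3->axis(1));  // true
bool exact = id_graph.getNodes(IdMappingMode::EXACT)
                 .strictAreMapped(tv2->axis(1), tv3->axis(1));  // false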

2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_shift.cpp
@@ -157,7 +157,7 @@ void HaloInfo::setRootAxisInfo(

HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map)
// Make a copy of the permissive map for extent comparators
: permissive_map_(ca_map->idGraph().permissiveNodes()) {
: permissive_map_(ca_map->idGraph().getNodes(IdMappingMode::PERMISSIVE)) {
const auto vals = fusion->usedMathVals();
auto tvs = ir_utils::filterByType<TensorView>(vals);

2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -489,7 +489,7 @@ bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) {
// Mark those as an active use of the rfactor; if two are detected, return
// true.
for (const auto& disjoint_set_shared_ptr :
ca_map.idGraph().exactNodes().disjointSets()) {
ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) {
// Make sure there's at least one rfactor domain in the set, otherwise we
// don't need to check anything from this set.
if (!std::any_of(
5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -50,8 +50,9 @@ class DomainMap : public pointwise_utils::DomainMap {
const auto& root_dom = tv->getRootDomain();
IterDomain* mapped_id = nullptr;
for (auto i : c10::irange(root_dom.size())) {
if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
root_dom[i], root_dim)) {
if (ca_map_.idGraph()
        .getNodes(IdMappingMode::PERMISSIVE)
        .permissiveAreMapped(root_dom[i], root_dim)) {
mapped_id = root_dom[i];
break;
}
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -2100,7 +2100,7 @@ void BoundedDirectionalTransformPropagator::bothWays(
DisjointSets<IterDomain*> disjointViewSets(Fusion* fusion) {
// Start from the exact iter domain graph of the fusion
IterDomainGraph id_graph(fusion);
auto disjoint_view_ids = id_graph.exactNodes();
auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT);

// If iter domains are involved in any transformation from root domains to
// rfactor domains they should be considered "contaminated".
@@ -2240,7 +2240,7 @@ void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {

std::unordered_set<IterDomain*> terminating_rfactor_dims;
for (const auto& disjoint_set_shared_ptr :
ca_map.idGraph().exactNodes().disjointSets()) {
ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) {
if (std::none_of(
disjoint_set_shared_ptr->vector().begin(),
disjoint_set_shared_ptr->vector().end(),
4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
@@ -149,7 +149,9 @@ namespace {
Val* commonOrConstExtent(
std::shared_ptr<const ComputeAtMap> ca_map,
IterDomain* id) {
auto disjoint_set = ca_map->idGraph().almostExactNodes().getDisjointSetOf(id);
auto disjoint_set = ca_map->idGraph()
.getNodes(IdMappingMode::ALMOSTEXACT)
.getDisjointSetOf(id);
for (auto entry : disjoint_set) {
if (entry->extent()->isConstScalar()) {
return entry->extent();
23 changes: 13 additions & 10 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp
@@ -1210,19 +1210,22 @@ TEST_F(NVFuserTest, FusionViewIdGraph_CUDA) {

// Start from the exact iter domain graph of the fusion
IterDomainGraph id_graph(&fusion);
auto disjoint_view_ids = id_graph.exactNodes();
auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT);

TORCH_CHECK(
    id_graph.exactNodes().strictAreMapped(tv2->axis(1), tv4->axis(1)));
TORCH_CHECK(
    id_graph.exactNodes().strictAreMapped(tv2->axis(2), tv4->axis(2)));

TORCH_CHECK(id_graph.exactNodes().strictAreMapped(
    tv2->getRootDomain()[1], tv12->getRootDomain()[1]));
TORCH_CHECK(id_graph.exactNodes().strictAreMapped(
    tv2->getRootDomain()[2], tv12->getRootDomain()[2]));
TORCH_CHECK(id_graph.exactNodes().strictAreMapped(
    tv2->getRootDomain()[3], tv12->getRootDomain()[3]));
TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT)
                .strictAreMapped(tv2->axis(1), tv4->axis(1)));
TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT)
                .strictAreMapped(tv2->axis(2), tv4->axis(2)));

TORCH_CHECK(
    id_graph.getNodes(IdMappingMode::EXACT)
        .strictAreMapped(tv2->getRootDomain()[1], tv12->getRootDomain()[1]));
TORCH_CHECK(
    id_graph.getNodes(IdMappingMode::EXACT)
        .strictAreMapped(tv2->getRootDomain()[2], tv12->getRootDomain()[2]));
TORCH_CHECK(
    id_graph.getNodes(IdMappingMode::EXACT)
        .strictAreMapped(tv2->getRootDomain()[3], tv12->getRootDomain()[3]));
}

TEST_F(NVFuserTest, FusionViewVectorize_CUDA) {