Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,11 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
IndexStmt divide(IndexVar i, IndexVar i1, IndexVar i2, size_t divideFactor) const; // TODO: TailStrategy


/// The loopfuse transformation fuses common outer loops in
/// 2 iteration graphs.
IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;


/// The reorder transformation swaps two directly nested index
/// variables in an iteration graph. This changes the order of
/// iteration through the space and the order of tensor accesses.
Expand Down
24 changes: 24 additions & 0 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class IndexStmt;
class TransformationInterface;
class Reorder;
class Precompute;
class LoopFuse;
class ForAllReplace;
class AddSuchThatPredicates;
class Parallelize;
Expand All @@ -32,6 +33,7 @@ class Transformation {
public:
Transformation(Reorder);
Transformation(Precompute);
Transformation(LoopFuse);
Transformation(ForAllReplace);
Transformation(Parallelize);
Transformation(TopoReorder);
Expand Down Expand Up @@ -114,6 +116,28 @@ class Precompute : public TransformationInterface {
/// Print a precompute command.
std::ostream &operator<<(std::ostream &, const Precompute &);

/// The loopfuse optimization rewrite an index expression to precompute
/// part of the `expr` and store it to a workspace.
class LoopFuse : public TransformationInterface {
public:
LoopFuse();
LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path);

int getPos() const;
bool getIsProducerOnLeft() const;
std::vector<int>& getPath() const;

/// Apply the loopfuse optimization to a concrete index statement.
IndexStmt apply(IndexStmt, std::string *reason = nullptr) const;

void print(std::ostream &os) const;

private:
struct Content;
std::shared_ptr<Content> content;
};

std::ostream &operator<<(std::ostream &, const LoopFuse &);

/// Replaces all occurrences of directly nested forall nodes of pattern with
/// directly nested loops of replacement
Expand Down
1 change: 1 addition & 0 deletions include/taco/parser/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum class Token {
sub,
mul,
div,
colon,
eq,
eot, // End of tokens
error
Expand Down
21 changes: 21 additions & 0 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1854,6 +1854,26 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
return transformed;
}

IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {

std::cout << "Loop fuse pos: " << pos;
std::cout << ", Loop fuse isProducerOnLeft: " << isProducerOnLeft;
for (const auto& p : path) {
std::cout << " " << p;
}
std::cout << std::endl;

string reason;
IndexStmt transformed = *this;
transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
if (!transformed.defined()) {
taco_uerror << reason;
}
return transformed;

return *this;
}

IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
std::vector<IndexVar> iw_vars, TensorVar workspace) const {

Expand Down Expand Up @@ -2048,6 +2068,7 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
return transformed;
}


IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
if (accelIndexVars.size() == 0) {
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
Expand Down
Loading