Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions third_party/nvfuser/csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ void FusionExecutor::compileFusion(
}
}

if (isDebugDumpEnabled(DebugDumpOption::FusionDebug)) {
fusion->printDebug();
}

if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
fusion->print();
} else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
Expand Down
8 changes: 6 additions & 2 deletions third_party/nvfuser/csrc/expr_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,9 @@ class FlattenedAssocCommOp : public Expr {
return other_inputs.empty();
}

std::string toString(int indent_size = 0) const override {
std::string toString(
int indent_size = 0,
SerializationFormat fmt = SerializationFormat::Default) const override {
std::stringstream ss;
indent(ss, indent_size) << getOpString() << "(";
bool needs_comma = false;
Expand All @@ -426,7 +428,9 @@ class FlattenedAssocCommOp : public Expr {
return ss.str();
}

std::string toInlineString(int = 0) const override {
std::string toInlineString(
int = 0,
SerializationFormat fmt = SerializationFormat::Default) const override {
std::stringstream ss;
ss << getOpString() << "(";
bool needs_comma = false;
Expand Down
234 changes: 197 additions & 37 deletions third_party/nvfuser/csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <kernel.h>
#include <lower2device.h>
#include <lower_bank_conflict.h>
#include <utils.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -344,26 +345,150 @@ void Fusion::validateInputs() {
}
}

void Fusion::print() {
void Fusion::serialize(std::ostream& out, SerializationFormat fmt) {
FUSER_PERF_SCOPE("Fusion::serialize");

switch (fmt) {
case SerializationFormat::NameOnly:
out << "Fusion";
break;
case SerializationFormat::Default: {
FusionGuard fg(this);
out << "\n%kernel {\n";
IrMathPrinter op_exprs(out);
op_exprs.handle(this);
out << "\nTransformPrinter : \n";
IrTransformPrinter t_exprs(out);
t_exprs.handle(this);
out << "}\n\n";
break;
}
case SerializationFormat::Debug: {
break;
}
case SerializationFormat::EndOfOption:
break;
}
}

void Fusion::printDebug(std::ostream& out) {
FUSER_PERF_SCOPE("Fusion::printDebug");

out << "Fusion DEBUG INFO {";
std::vector<Val*> inputs_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This vector aliases the fusion's inputs_ vector so it always shows that the fusion doesn't have any inputs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we just need to delete this line.

out << "\n inputs_ = {";
for (auto& it : inputs_) {
out << "\n " << it->toString(0, SerializationFormat::Debug);
}
out << " }\n";
out << "\n outputs_ = {";
for (auto& it : outputs_) {
out << "\n " << it->toString(0, SerializationFormat::Debug);
auto a = getOutputAlias(it);
if (a != nullptr) {
out << " ALIASES " << a->toString(0, SerializationFormat::NameOnly);
}
}
out << " }\n";
out << "\n all_tv_uses_valid_ = " << all_tv_uses_valid_;
out << "\n is_during_update_uses_ = " << is_during_update_uses_;
out << "\n io_alias_ = {";
for (auto& it : io_alias_) { // NOTE: ordering arbitrary
out << "\n " << it.first->toString(0, SerializationFormat::Debug)
<< " => " << it.second->toString(0, SerializationFormat::Debug);
}
out << " }\n";
out << "\n permuted_input_map_ = {";
for (auto& it : permuted_input_map_) { // NOTE: ordering arbitrary
out << "\n " << it.first << " => " << it.second;
}
out << " }\n";
out << "\n permuted_output_map_ = {";
for (auto& it : permuted_input_map_) { // NOTE: ordering arbitrary
out << "\n " << it.first << " => " << it.second;
}
out << " }\n";

auto ind = " ";
out << ind << " expr_name_counter = " << expr_name_counter_;
out << ind << "\n vals_ (" << vals_.size() << ") = [";
std::vector<std::string> valstrs;
std::vector<std::tuple<int, std::string, std::string>> all_logs;
for (auto& it : vals_) { // NOTE: ordering arbitrary
std::stringstream obss;
obss << it->toString(3, SerializationFormat::Debug);
auto log = it->getLogMessages();
if (log.size() > 0) {
auto val_name = it->toString(0, SerializationFormat::NameOnly);
for (auto num_msg :
it->getLogMessages()) { // pair log sequence number & message
all_logs.push_back({num_msg.first, val_name, num_msg.second});
}
}
valstrs.push_back(obss.str());
}
std::sort(valstrs.begin(), valstrs.end());
for (auto& it : valstrs) { // sorted
out << ind << "\n " << it;
}
out << ind << " ]\n";
out << ind << "\n exprs_ (" << exprs_.size() << ") = [\n";
std::vector<std::string> expstrs;
for (auto& it : exprs_) { // NOTE: ordering arbitrary
expstrs.push_back(it->toString(3, SerializationFormat::NameOnly));
}
std::sort(expstrs.begin(), expstrs.end());
for (auto& it : expstrs) { // sorted
out << ind << it;
}
out << ind << " ]\n";
out << ind << "\n val_type_name_map_ (" << val_type_name_map_.size()
<< ") = {";
for (auto& it : val_type_name_map_) { // NOTE: ordering arbitrary
out << ind << "\n " << (int)it.first << " => " << it.second;
}
out << ind << " }\n";
std::sort(all_logs.begin(), all_logs.end());
out << ind << "\n Logged operations:";
#ifdef NDEBUG
std::cerr << "WARNING: Fusion operations are only logged in Debug builds."
<< std::endl;
#endif
// Actual lognum may be large if there are multiple Fusions defined in this
// process. Instead, just print a local counter for the log messages
// appearing in this Fusion's vals_.
int local_lognum = 0;
for (auto entry : all_logs) {
int lognum;
std::string val_name;
std::string msg;
std::tie(lognum, val_name, msg) = entry;
out << ind << "\n " << local_lognum++ << ") " << val_name << " : "
<< msg;
}
out << "\n}\n";
}

void Fusion::print(std::ostream& out, SerializationFormat fmt) {
FUSER_PERF_SCOPE("Fusion::print");

FusionGuard fg(this);
std::cout << "\n%kernel {\n";
IrMathPrinter op_exprs(std::cout);
out << "\n%kernel {\n";
IrMathPrinter op_exprs(out, fmt);
op_exprs.handle(this);
std::cout << "\nTransformPrinter : \n";
IrTransformPrinter t_exprs(std::cout);
out << "\nTransformPrinter : \n";
IrTransformPrinter t_exprs(out, fmt);
t_exprs.handle(this);
std::cout << "}\n\n";
out << "}\n\n";
}

void Fusion::printKernel(DataType index_type) {
void Fusion::printKernel(DataType index_type, std::ostream& out) {
FUSER_PERF_SCOPE("Fusion::printKernel");
TORCH_INTERNAL_ASSERT(
!this->isA<kir::Kernel>(),
"Cannot \"print kernel\" of a kernel container. ",
"This would require lowering during lowering.");
std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel());
out << codegen::generateCudaKernel(GpuLower(this, index_type).kernel());
}

std::unordered_map<std::string, std::pair<int, int>> Fusion::bankConflictInfo(
Expand All @@ -380,38 +505,54 @@ std::unordered_map<std::string, std::pair<int, int>> Fusion::bankConflictInfo(
return result;
}

void Fusion::printMath(bool from_outputs_only) {
void Fusion::printMath(
bool from_outputs_only,
std::ostream& out,
SerializationFormat fmt) {
FUSER_PERF_SCOPE("Fusion::printMath");

FusionGuard fg(this);
auto exprs_for_print = exprs();
std::cout << "Inputs:" << std::endl;
for (auto inp : inputs()) {
std::cout << " " << inp << ", " << inp->getDataType().value() << std::endl;
}
switch (fmt) {
case SerializationFormat::NameOnly:
out << "Fusion Math";
break;
case SerializationFormat::Default: {
FusionGuard fg(this);
auto exprs_for_print = exprs();
out << "Inputs:" << std::endl;
for (auto inp : inputs()) {
out << " " << inp << ", " << inp->getDataType().value() << std::endl;
}

std::cout << "Outputs:" << std::endl;
for (auto out : outputs()) {
std::cout << " " << out << ", " << out->getDataType().value() << std::endl;
}
out << "Outputs:" << std::endl;
for (auto output : outputs()) {
out << " " << output << ", " << output->getDataType().value()
<< std::endl;
}

// If we want everything in the fusion, grab all values without uses to
// traverse from.
if (!from_outputs_only) {
std::vector<Val*> leaf_vals;
for (auto val : deterministic_vals()) {
if (val->uses().empty()) {
leaf_vals.push_back(val);
// If we want everything in the fusion, grab all values without uses to
// traverse from.
if (!from_outputs_only) {
std::vector<Val*> leaf_vals;
for (auto val : deterministic_vals()) {
if (val->uses().empty()) {
leaf_vals.push_back(val);
}
}
exprs_for_print = StmtSort::getExprs(this, leaf_vals);
}
}
exprs_for_print = StmtSort::getExprs(this, leaf_vals);
}

std::cout << "\n%kernel_math {\n";
for (auto expr : exprs_for_print) {
std::cout << expr;
out << "\n%kernel_math {\n";
for (auto expr : exprs_for_print) {
out << expr;
}
out << "}\n\n";
break;
}
case SerializationFormat::Debug:
break;
case SerializationFormat::EndOfOption:
break;
}
std::cout << "}\n\n";
}

std::vector<Val*> Fusion::inputsAndCreated() {
Expand All @@ -427,12 +568,25 @@ std::vector<Val*> Fusion::inputsAndCreated() {
return result;
}

void Fusion::printTransforms() {
void Fusion::printTransforms(std::ostream& out, SerializationFormat fmt) {
FUSER_PERF_SCOPE("Fusion::printTransforms");

FusionGuard fg(this);
IrTransformPrinter t_exprs(std::cout);
t_exprs.handle(this);
switch (fmt) {
case SerializationFormat::NameOnly:
out << "Fusion Transforms";
break;
case SerializationFormat::Default: {
FusionGuard fg(this);
IrTransformPrinter t_exprs(out);
t_exprs.handle(this);
break;
}
case SerializationFormat::Debug: {
out << "DEBUG OUTPUT:" << std::endl;
}
case SerializationFormat::EndOfOption:
break;
}
}

void Fusion::registerVal(Val* val) {
Expand Down Expand Up @@ -660,6 +814,12 @@ bool Fusion::isAliasCompatible(Val* left, Val* right) {
}

void Fusion::aliasOutputToInput(Val* output, Val* input) {
VAL_LOG_EXPLICIT(
output,
"Fusion::aliasOutputToInput",
output->toString(0, SerializationFormat::NameOnly),
input->toString(0, SerializationFormat::NameOnly), );

// Because we could cast output when input is cast.
TORCH_INTERNAL_ASSERT(
!output->isFusionOutput(),
Expand Down
26 changes: 22 additions & 4 deletions third_party/nvfuser/csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,36 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! Assert that all leaves found from outputs are registered as an input
void validateInputs();

//! Serialize in text or binary form using one of many formats
void serialize(std::ostream& out, SerializationFormat fmt);

//! Deserialize from the given format
void deserialize(std::istream& in, SerializationFormat fmt);

//! Print detailed debug information about this fusion to the console
void printDebug(std::ostream& out = std::cout);

//! Print this fusion to the console
void print();
void print(
std::ostream& out = std::cout,
SerializationFormat fmt = SerializationFormat::Default);

//! Print Arith exprs
//! \param from_outputs_only Only print exprs reachable from outputs
void printMath(bool from_outputs_only = true);
void printMath(
bool from_outputs_only = true,
std::ostream& out = std::cout,
SerializationFormat fmt = SerializationFormat::Default);

//! Print transformations used in fusion (can be very verbose)
void printTransforms();
void printTransforms(
std::ostream& out = std::cout,
SerializationFormat fmt = SerializationFormat::Default);

//! Lower the fusion and print a kernel
void printKernel(DataType index_type = DataType::Int);
void printKernel(
DataType index_type = DataType::Int,
std::ostream& out = std::cout);

//! Returns if this fusion is noop, for example, trivially forwarding inputs,
//! or all outputs are size-0 tensors, etc.
Expand Down
6 changes: 4 additions & 2 deletions third_party/nvfuser/csrc/ir_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ bool Statement::lessThan(const Statement* stmt1, const Statement* stmt2) {
return stmt1->name() < stmt2->name();
}

std::string Statement::toString(int indent_size) const {
std::string Statement::toString(int indent_size, SerializationFormat fmt)
const {
TORCH_INTERNAL_ASSERT(
false, "toString for IR node ", typeid(*this).name(), " is not defined");
}

std::string Statement::toInlineString(int indent_size) const {
std::string Statement::toInlineString(int indent_size, SerializationFormat fmt)
const {
TORCH_INTERNAL_ASSERT(
false,
"toInlineString for IR node ",
Expand Down
Loading