From d823e14854f6fa37255e194e6bd7040d55f1766e Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Fri, 22 Nov 2024 16:08:05 -0600 Subject: [PATCH 01/20] build setup --- .vscode/settings.json | 5 +++ src/enzyme_ad/jax/BUILD | 45 ++++++++++++++++++++++ src/enzyme_ad/jax/Dialects/CommDialect.cpp | 1 + src/enzyme_ad/jax/Dialects/CommDialect.h | 1 + src/enzyme_ad/jax/Dialects/CommDialect.td | 14 +++++++ src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 5 +++ test/lit_tests/unroll.mlir | 2 +- 7 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.cpp create mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.h create mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.td diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..ef4dc2f6a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.associations": { + "*.inc": "cpp" + } +} \ No newline at end of file diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 632d89ad6..00de3de36 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -90,6 +90,35 @@ gentbl_cc_library( tblgen = "@llvm-project//mlir:mlir-tblgen", ) + +td_library( + name = "CommDialectFiles", + srcs = [ + "Dialects/CommDialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles" + ] +) + + +gentbl_cc_library( + name = "CommDialectIncGen", + tbl_outs = [( + ["-gen-dialect-decls"], + "Dialects/CommDialect.h.inc", + ), ( + ["-gen-dialect-defs"], + "Dialects/CommDialect.cpp.inc", + ), + ], + td_file = "Dialects/CommDialect.td", + deps = [ + ":CommDialectFiles", + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", +) + gentbl_cc_library( name = "TransformOpsImplIncGen", tbl_outs = [( @@ -116,10 +145,22 @@ cc_library( "@llvm-project//mlir:TransformDialectInterfaces", ":TransformOpsIncGen", ":TransformOpsImplIncGen", + ":CommDialectIncGen", ":XLADerivatives", ], ) +cc_library( + name = 
"Dialects", + srcs = glob(["Dialects/*.cpp"]), + hdrs = glob(["Dialects/*.h"]), + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ":CommDialectIncGen", + ], +) + td_library( name = "ImplementationsCommonTdFiles", srcs = [ @@ -228,11 +269,13 @@ cc_library( [ "Implementations/*.cpp", "Passes/*.cpp", + "Dialects/*.cpp", ], ), hdrs = glob([ "Implementations/*.h", "Passes/*.h", + "Dialects/*.h", ]), copts = [ "-Werror=unused-variable", @@ -280,6 +323,7 @@ pybind_library( deps = [ ":XLADerivatives", ":TransformOps", + ":Dialects", # This is similar to xla_binary rule and is needed to make XLA client compile. # "@tsl//tsl/framework:allocator", # "@tsl//tsl/framework:allocator_registry_impl", @@ -373,6 +417,7 @@ pybind_extension( ":clang_compile", ":compile_with_xla", ":TransformOps", + ":Dialects", "@com_google_absl//absl/status:statusor", "@enzyme//:EnzymeMLIR", "@enzyme//:EnzymeStatic", diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/CommDialect.cpp new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/CommDialect.cpp @@ -0,0 +1 @@ + diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.h b/src/enzyme_ad/jax/Dialects/CommDialect.h new file mode 100644 index 000000000..af0355776 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/CommDialect.h @@ -0,0 +1 @@ +#include "src/enzyme_ad/jax/Dialects/CommDialect.h.inc" diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.td b/src/enzyme_ad/jax/Dialects/CommDialect.td new file mode 100644 index 000000000..91243b860 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/CommDialect.td @@ -0,0 +1,14 @@ +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" + +def CommunicationDialect : Dialect { + let name = "comm"; + let summary = "A prototype dialect for various communication ops"; + let description = [{}]; + let cppNamespace = "::mlir::comm"; +} + +class CommOp traits = []> : Op; + +def CommFoo : 
CommOp<"ComFooStr">; diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 33c0c73c5..86ad1fa5e 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -47,6 +47,10 @@ #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "Dialects/CommDialect.h" + + + using namespace mlir; namespace mlir { @@ -89,6 +93,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); diff --git a/test/lit_tests/unroll.mlir b/test/lit_tests/unroll.mlir index da21e959e..bbe8c98f6 100644 --- a/test/lit_tests/unroll.mlir +++ b/test/lit_tests/unroll.mlir @@ -3,7 +3,7 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - + comm.CommFooStr %start = stablehlo.constant dense<0> : tensor %lim = stablehlo.constant dense<5> : tensor From b255eda62e8ea031304767eb8aca89e343087299 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 4 Dec 2024 18:58:20 -0600 Subject: [PATCH 02/20] Recognize op --- .vscode/settings.json | 62 +++++++++++++++++++++- src/enzyme_ad/jax/BUILD | 10 +++- src/enzyme_ad/jax/Dialects/CommDialect.cpp | 15 ++++++ src/enzyme_ad/jax/Dialects/CommDialect.h | 10 ++++ src/enzyme_ad/jax/Dialects/CommDialect.td | 9 +++- src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp | 32 +++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 7 +++ src/enzyme_ad/jax/Passes/Passes.td | 10 ++++ src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 5 +- test/com_tests/unroll.mlir | 33 ++++++++++++ test/lit_tests/unroll.mlir | 1 - 11 files changed, 187 insertions(+), 7 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp create mode 100644 test/com_tests/unroll.mlir diff --git a/.vscode/settings.json b/.vscode/settings.json index ef4dc2f6a..6387f84fe 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,65 @@ { "files.associations": { - "*.inc": "cpp" + "*.inc": "cpp", + "algorithm": "cpp", + "cstddef": 
"cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "string": "cpp", + "cmath": "cpp", + "typeinfo": "cpp", + "cstdlib": "cpp", + "limits": "cpp", + "new": "cpp", + "type_traits": "cpp", + "vector": "cpp", + "__verbose_abort": "cpp", + "array": "cpp", + "cstring": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "execution": "cpp", + "memory": "cpp", + "initializer_list": "cpp", + "iosfwd": "cpp", + "list": "cpp", + "stdexcept": "cpp", + "unordered_map": "cpp", + "variant": "cpp", + "atomic": "cpp", + "bit": "cpp", + "*.tcc": "cpp", + "compare": "cpp", + "concepts": "cpp", + "exception": "cpp", + "functional": "cpp", + "iterator": "cpp", + "memory_resource": "cpp", + "random": "cpp", + "tuple": "cpp", + "utility": "cpp", + "cctype": "cpp", + "clocale": "cpp", + "complex": "cpp", + "condition_variable": "cpp", + "ctime": "cpp", + "deque": "cpp", + "map": "cpp", + "set": "cpp", + "istream": "cpp", + "mutex": "cpp", + "numbers": "cpp", + "numeric": "cpp", + "optional": "cpp", + "ostream": "cpp", + "ratio": "cpp", + "semaphore": "cpp", + "shared_mutex": "cpp", + "sstream": "cpp", + "stop_token": "cpp", + "streambuf": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "thread": "cpp" } } \ No newline at end of file diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 00de3de36..8f0f297a5 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -110,11 +110,18 @@ gentbl_cc_library( ), ( ["-gen-dialect-defs"], "Dialects/CommDialect.cpp.inc", + ),( + ["-gen-op-decls"], + "Dialects/CommOps.h.inc", + ), ( + ["-gen-op-defs"], + "Dialects/CommOps.cpp.inc", ), ], td_file = "Dialects/CommDialect.td", deps = [ ":CommDialectFiles", + "@llvm-project//mlir:OpBaseTdFiles", ], tblgen = "@llvm-project//mlir:mlir-tblgen", ) @@ -152,7 +159,7 @@ cc_library( cc_library( name = "Dialects", - srcs = glob(["Dialects/*.cpp"]), + srcs = glob(["Dialects/*.cpp", "Dialects/CommDialect.cpp.inc"]), hdrs = glob(["Dialects/*.h"]), deps = [ "@llvm-project//mlir:IR", @@ 
-284,6 +291,7 @@ cc_library( "-Werror=unused-result", ], deps = [ + ":Dialects", ":EnzymeXLAPassesIncGen", ":EnzyeHLOPatternsIncGen", ":mhlo-derivatives", diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/CommDialect.cpp index 8b1378917..401bbaa56 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/CommDialect.cpp @@ -1 +1,16 @@ +#include "src/enzyme_ad/jax/Dialects/CommDialect.h" +#include "src/enzyme_ad/jax/Dialects/CommDialect.cpp.inc" +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialects/CommOps.cpp.inc" + +using namespace mlir; +using namespace mlir::comm; + + +void CommunicationDialect::initialize() { + addOperations< + CommFoo>(); // Register CommFoo operation +} + +// CommunicationDialect::CommunicationDialect(mlir::MLIRContext*){} \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.h b/src/enzyme_ad/jax/Dialects/CommDialect.h index af0355776..611c35a34 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.h +++ b/src/enzyme_ad/jax/Dialects/CommDialect.h @@ -1 +1,11 @@ +#include "mlir/IR/Dialect.h" +#include "mlir/Support/TypeID.h" +#include "mlir/include/mlir/IR/DialectImplementation.h" #include "src/enzyme_ad/jax/Dialects/CommDialect.h.inc" + +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialects/CommOps.h.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.td b/src/enzyme_ad/jax/Dialects/CommDialect.td index 91243b860..67141f06b 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/CommDialect.td @@ -11,4 +11,11 @@ def CommunicationDialect : Dialect { class CommOp traits = []> : Op; -def CommFoo : CommOp<"ComFooStr">; +def CommFoo : CommOp<"Foo"> { + let summary = "do-nothing test op"; + let 
arguments = (ins ); + let results = (outs ); + let assemblyFormat = [{ + attr-dict + }]; +} diff --git a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp new file mode 100644 index 000000000..dfa432280 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp @@ -0,0 +1,32 @@ +#include "src/enzyme_ad/jax/Dialects/CommDialect.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::comm; +using namespace enzyme; using namespace mlir::enzyme; // one of the upstream includes we need is wrapped in this namespace + +namespace { +struct CommRemoveFoo : public CommRemoveFooBase { + + void runOnOperation() override { + mlir::Operation* op = getOperation(); + + // if op is CommFoo, erase + if(isa(op)){ + op->erase(); + } + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace comm { +std::unique_ptr createCommRemoveFooPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 584ec22ce..08c35b035 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -22,6 +22,9 @@ std::unique_ptr createEnzymeHLOOptPass(); std::unique_ptr createEnzymeHLOUnrollPass(); std::unique_ptr createPrintPass(); } // namespace enzyme +namespace comm { +std::unique_ptr createCommRemoveFooPass(); +} } // namespace mlir namespace mlir { @@ -70,6 +73,10 @@ namespace LLVM { class LLVMDialect; } +namespace comm { + class CommunicationDialect; +} + #define GEN_PASS_REGISTRATION #include "src/enzyme_ad/jax/Passes/Passes.h.inc" diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 2aafc894e..ce1346323 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -11,6 +11,16 @@ include "mlir/Pass/PassBase.td" + +def 
CommRemoveFoo : Pass<"remove-comm-foo"> { + let summary = "Removes all comm foo operations"; + let dependentDialects = [ + "comm::CommunicationDialect" + ]; + let constructor = "mlir::comm::createCommRemoveFooPass()"; +} + + def ArithRaisingPass : Pass<"arith-raise"> { let summary = "Raise Arith to mhlo"; let dependentDialects = [ diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 86ad1fa5e..820ad921b 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -92,10 +92,9 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); - registry.insert(); - + registry.insert(); registry.insert(); + registry.insert(); mlir::registerenzymePasses(); regsiterenzymeXLAPasses(); diff --git a/test/com_tests/unroll.mlir b/test/com_tests/unroll.mlir new file mode 100644 index 000000000..e04d5cf6c --- /dev/null +++ b/test/com_tests/unroll.mlir @@ -0,0 +1,33 @@ +// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s + +module { + + func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { + comm.Foo + %start = stablehlo.constant dense<0> : tensor + + %lim = stablehlo.constant dense<5> : tensor + + %step = stablehlo.constant dense<1> : tensor + + %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + return %w#0 : tensor<2x2xf32> + } +} + +// CHECK: func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg0 : tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.add %1, %1 
: tensor<2x2xf32> +// CHECK-NEXT: %3 = stablehlo.add %2, %2 : tensor<2x2xf32> +// CHECK-NEXT: %4 = stablehlo.add %3, %3 : tensor<2x2xf32> +// CHECK-NEXT: return %4 : tensor<2x2xf32> +// CHECK-NEXT: } diff --git a/test/lit_tests/unroll.mlir b/test/lit_tests/unroll.mlir index bbe8c98f6..a2433e0b8 100644 --- a/test/lit_tests/unroll.mlir +++ b/test/lit_tests/unroll.mlir @@ -3,7 +3,6 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.CommFooStr %start = stablehlo.constant dense<0> : tensor %lim = stablehlo.constant dense<5> : tensor From e2e1967261ceb1318605c1a05ecee0a472d50c85 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 4 Dec 2024 19:28:33 -0600 Subject: [PATCH 03/20] Run pass test --- src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp | 8 +++----- src/enzyme_ad/jax/Passes/Passes.h | 1 + .../{com_tests/unroll.mlir => lit_tests/comm_unroll.mlir} | 0 3 files changed, 4 insertions(+), 5 deletions(-) rename test/{com_tests/unroll.mlir => lit_tests/comm_unroll.mlir} (100%) diff --git a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp index dfa432280..d2d362a82 100644 --- a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp +++ b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp @@ -13,11 +13,9 @@ struct CommRemoveFoo : public CommRemoveFooBase { void runOnOperation() override { mlir::Operation* op = getOperation(); - - // if op is CommFoo, erase - if(isa(op)){ - op->erase(); - } + op->walk([](CommFoo foop){ + foop->erase(); + }); } }; diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 08c35b035..ec6a23b5f 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -88,5 +88,6 @@ static void regsiterenzymeXLAPasses() { registerPrintPass(); registerEnzymeHLOOptPass(); registerEnzymeHLOUnrollPass(); + registerCommRemoveFooPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/test/com_tests/unroll.mlir b/test/lit_tests/comm_unroll.mlir similarity index 100% 
rename from test/com_tests/unroll.mlir rename to test/lit_tests/comm_unroll.mlir From 37845dca1884dbd71ae51543c1ae2f569a8afdad Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Tue, 17 Dec 2024 16:47:47 -0600 Subject: [PATCH 04/20] Added branch, join --- src/enzyme_ad/jax/Dialects/CommDialect.cpp | 6 ++-- src/enzyme_ad/jax/Dialects/CommDialect.h | 1 + src/enzyme_ad/jax/Dialects/CommDialect.td | 38 +++++++++++++++++++++- test/lit_tests/comm_unroll.mlir | 7 +++- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/CommDialect.cpp index 401bbaa56..fad6ea992 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/CommDialect.cpp @@ -7,10 +7,10 @@ using namespace mlir; using namespace mlir::comm; - void CommunicationDialect::initialize() { - addOperations< - CommFoo>(); // Register CommFoo operation + addOperations(); // Register CommFoo operation + addOperations(); + addOperations(); } // CommunicationDialect::CommunicationDialect(mlir::MLIRContext*){} \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.h b/src/enzyme_ad/jax/Dialects/CommDialect.h index 611c35a34..6c60c90f1 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.h +++ b/src/enzyme_ad/jax/Dialects/CommDialect.h @@ -1,4 +1,5 @@ #include "mlir/IR/Dialect.h" +#include "mlir/IR/Builders.h" #include "mlir/Support/TypeID.h" #include "mlir/include/mlir/IR/DialectImplementation.h" #include "src/enzyme_ad/jax/Dialects/CommDialect.h.inc" diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.td b/src/enzyme_ad/jax/Dialects/CommDialect.td index 67141f06b..89ed2c148 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/CommDialect.td @@ -1,6 +1,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/DialectBase.td" +include "mlir/IR/Traits.td" def CommunicationDialect : Dialect { let name = "comm"; @@ -11,7 +12,7 @@ def 
CommunicationDialect : Dialect { class CommOp traits = []> : Op; -def CommFoo : CommOp<"Foo"> { +def CommFoo : CommOp<"foo"> { let summary = "do-nothing test op"; let arguments = (ins ); let results = (outs ); @@ -19,3 +20,38 @@ def CommFoo : CommOp<"Foo"> { attr-dict }]; } + +def CommJoin : CommOp<"join", traits = [Terminator]> { + let summary = "Denotes the end of a split block, similar to ret for a function"; + let arguments = (ins ); + let results = (outs ); + let assemblyFormat = [{ + attr-dict + }]; +} + +def CommSplitBranch : CommOp<"split_branch"> { + let summary = "One of the branches taken by different devices in a split op."; + let description = [{ + Inside of a split op, this branch will execute on the provided static list of devices. + The code inside of this branch will have access to any communication items declared in + its parent split block, or anything in an outside scope. + }]; + + // Takes in list of participating devices (TODO- currently list of int literals), code region as attributes + let arguments = (ins + I32Attr:$device_list + ); + let regions = (region + AnyRegion:$branch_code // TODO constraint on region? + ); + + let results = (outs ); // nothing? + let assemblyFormat = [{ + attr-dict `<` $device_list `>` $branch_code + }]; +} + +// def CommSplit : CommOp<"Split"> { +// let summary = "The highest level split node in the communication dialect." 
+// } \ No newline at end of file diff --git a/test/lit_tests/comm_unroll.mlir b/test/lit_tests/comm_unroll.mlir index e04d5cf6c..d7dda4838 100644 --- a/test/lit_tests/comm_unroll.mlir +++ b/test/lit_tests/comm_unroll.mlir @@ -3,7 +3,12 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.Foo + comm.foo + comm.split_branch <1> { + ^start: + comm.foo + comm.join + } %start = stablehlo.constant dense<0> : tensor %lim = stablehlo.constant dense<5> : tensor From 0c3e91b71c9274c6daabe66d50a2f3c478508211 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Tue, 17 Dec 2024 17:43:06 -0600 Subject: [PATCH 05/20] Multiple branch devices --- src/enzyme_ad/jax/BUILD | 6 ++++-- src/enzyme_ad/jax/Dialects/CommDialect.td | 3 ++- test/lit_tests/comm_unroll.mlir | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 8f0f297a5..ca400de1b 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -105,10 +105,10 @@ td_library( gentbl_cc_library( name = "CommDialectIncGen", tbl_outs = [( - ["-gen-dialect-decls"], + ["-gen-dialect-decls", "-dialect=comm"], "Dialects/CommDialect.h.inc", ), ( - ["-gen-dialect-defs"], + ["-gen-dialect-defs", "-dialect=comm"], "Dialects/CommDialect.cpp.inc", ),( ["-gen-op-decls"], @@ -122,6 +122,8 @@ gentbl_cc_library( deps = [ ":CommDialectFiles", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:AttrTdFiles", + "@llvm-project//mlir:BuiltinDialectTdFiles" ], tblgen = "@llvm-project//mlir:mlir-tblgen", ) diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.td b/src/enzyme_ad/jax/Dialects/CommDialect.td index 89ed2c148..d34b2516b 100644 --- a/src/enzyme_ad/jax/Dialects/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/CommDialect.td @@ -1,5 +1,6 @@ include "mlir/IR/OpBase.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/DialectBase.td" include "mlir/IR/Traits.td" @@ -40,7 +41,7 @@ def CommSplitBranch : 
CommOp<"split_branch"> { // Takes in list of participating devices (TODO- currently list of int literals), code region as attributes let arguments = (ins - I32Attr:$device_list + ArrayAttr:$device_list ); let regions = (region AnyRegion:$branch_code // TODO constraint on region? diff --git a/test/lit_tests/comm_unroll.mlir b/test/lit_tests/comm_unroll.mlir index d7dda4838..d01277874 100644 --- a/test/lit_tests/comm_unroll.mlir +++ b/test/lit_tests/comm_unroll.mlir @@ -4,7 +4,7 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { comm.foo - comm.split_branch <1> { + comm.split_branch <[1, 5]> { ^start: comm.foo comm.join From 073e93562bb9f5f7b25bf6b9a03b292318a7e6f0 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Thu, 16 Jan 2025 17:00:44 -0600 Subject: [PATCH 06/20] Refactor folders --- src/enzyme_ad/jax/BUILD | 34 +++++--- src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp | 9 ++ .../jax/Dialects/Comm/CommDialect.cpp | 18 ++++ src/enzyme_ad/jax/Dialects/Comm/CommDialect.h | 23 +++++ .../jax/Dialects/Comm/CommDialect.td | 87 +++++++++++++++++++ src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 35 ++++++++ src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp | 7 ++ src/enzyme_ad/jax/Dialects/CommDialect.cpp | 16 ---- src/enzyme_ad/jax/Dialects/CommDialect.h | 12 --- src/enzyme_ad/jax/Dialects/CommDialect.td | 58 ------------- src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp | 2 +- src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 2 +- test/lit_tests/comm_unroll.mlir | 12 ++- 13 files changed, 211 insertions(+), 104 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommDialect.h create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommDialect.td create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp delete mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.cpp 
delete mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.h delete mode 100644 src/enzyme_ad/jax/Dialects/CommDialect.td diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index ca400de1b..a371d66b4 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -94,7 +94,7 @@ gentbl_cc_library( td_library( name = "CommDialectFiles", srcs = [ - "Dialects/CommDialect.td", + "Dialects/Comm/CommDialect.td", ], deps = [ "@llvm-project//mlir:OpBaseTdFiles" @@ -106,19 +106,31 @@ gentbl_cc_library( name = "CommDialectIncGen", tbl_outs = [( ["-gen-dialect-decls", "-dialect=comm"], - "Dialects/CommDialect.h.inc", + "Dialects/Comm/CommDialect.h.inc", ), ( ["-gen-dialect-defs", "-dialect=comm"], - "Dialects/CommDialect.cpp.inc", + "Dialects/Comm/CommDialect.cpp.inc", ),( - ["-gen-op-decls"], - "Dialects/CommOps.h.inc", + ["-gen-op-decls", "-dialect=comm"], + "Dialects/Comm/CommOps.h.inc", ), ( - ["-gen-op-defs"], - "Dialects/CommOps.cpp.inc", + ["-gen-op-defs", "-dialect=comm"], + "Dialects/Comm/CommOps.cpp.inc", + ),( + ["-gen-attrdef-decls", "--attrdefs-dialect=comm"], + "Dialects/Comm/CommAttrs.h.inc", + ),( + ["-gen-attrdef-defs", "--attrdefs-dialect=comm"], + "Dialects/Comm/CommAttrs.cpp.inc", + ),( + ["-gen-typedef-decls", "--typedefs-dialect=comm"], + "Dialects/Comm/CommTypes.h.inc", + ), ( + ["-gen-typedef-defs", "--typedefs-dialect=comm"], + "Dialects/Comm/CommTypes.cpp.inc", ), ], - td_file = "Dialects/CommDialect.td", + td_file = "Dialects/Comm/CommDialect.td", deps = [ ":CommDialectFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -161,8 +173,8 @@ cc_library( cc_library( name = "Dialects", - srcs = glob(["Dialects/*.cpp", "Dialects/CommDialect.cpp.inc"]), - hdrs = glob(["Dialects/*.h"]), + srcs = glob(["Dialects/*.cpp", "Dialects/Comm/*.cpp"]), + hdrs = glob(["Dialects/*.h", "Dialects/Comm/*.h"]), deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -278,13 +290,11 @@ cc_library( [ "Implementations/*.cpp", "Passes/*.cpp", - 
"Dialects/*.cpp", ], ), hdrs = glob([ "Implementations/*.h", "Passes/*.h", - "Dialects/*.h", ]), copts = [ "-Werror=unused-variable", diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp new file mode 100644 index 000000000..9cb38080f --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp @@ -0,0 +1,9 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "llvm/ADT/TypeSwitch.h" + + +using namespace mlir; +using namespace mlir::comm; + +#define GET_ATTRDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp new file mode 100644 index 000000000..4add08b7d --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -0,0 +1,18 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp.inc" + +using namespace mlir; +using namespace mlir::comm; + + +void CommunicationDialect::initialize() { + addOperations< + #define GET_OP_LIST + #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" + >(); + addAttributes< + #define GET_ATTR_LIST + #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" + >(); + // TODO types when we need them +} diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h new file mode 100644 index 000000000..a2b03bdad --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h @@ -0,0 +1,23 @@ +#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H +#define ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/TypeID.h" +#include "mlir/include/mlir/IR/DialectImplementation.h" +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h.inc" + +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include 
"mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.h.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.h.inc" + +#endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td new file mode 100644 index 000000000..39dfc9194 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -0,0 +1,87 @@ +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/Traits.td" + +def CommunicationDialect : Dialect { + let name = "comm"; + let summary = "A prototype dialect for various communication ops"; + let description = [{}]; + let cppNamespace = "::mlir::comm"; +} + +// Dialect inheritence shortcuts +class CommOp traits = []> : Op; +class CommAttr traits = []> : AttrDef{ + let mnemonic = mnemomic; +} +// class CommType traits = []> : TypeDef { +// let mnemonic = type_mnemonic; +// } + +/* +* Dialect Types +*/ +// def DeviceIdType : CommType<"DeviceId", "device_id"> { +// let summary="Wrapper around int to specify a device from our device set"; +// let parameters=(ins "unsigned":$id); +// let assemblyFormat=[{`d` $id }]; +// } + +/* +* Dialect Attributes +*/ + +// def CommDeviceIdAttr : CommAttr<"DeviceIdAttr", "device_id_attr">{ +// let parameters=(ins DeviceIdType:$id); +// let assemblyFormat = [{ $id }]; +// } + +def CommSplitBranchDescriptor : CommAttr<"SplitBranchDescriptor", "branch_descriptor"> { + let parameters= (ins ArrayRefParameter<"unsigned">:$device_ids); + let assemblyFormat = [{ $device_ids }]; +} + +/* +* Dialect Ops +*/ + +def CommFoo : CommOp<"foo"> { + let summary = "do-nothing test op"; + let arguments = (ins ); + 
let results = (outs ); + let assemblyFormat = [{ + attr-dict + }]; +} + +// Return, for end of split blocks. We may just be able to use return- lets see if there's any special +// semantics we want join to have +def CommJoin : CommOp<"join", traits = [Terminator]> { + let summary = "Denotes the end of a split block, similar to ret for a function"; + let arguments = (ins ); + let results = (outs ); + let assemblyFormat = [{ + attr-dict + }]; +} + +def CommSplit : CommOp<"split"> { + let summary = "The highest level split node in the communication dialect."; + let description = [{ + Takes in a definition of communication items and a list of split branches for devices to take. + }]; + + let arguments = (ins + ArrayAttr:$branch_metadata // array of attributes of type CommSplitBranchDescriptor. Parallel array with $branches. + ); + let regions = (region + VariadicRegion:$branches // regions for each branch. Parallel array with $branch_metadata + ); + let results = (outs ); + + let assemblyFormat = [{ + attr-dict `{` custom($branch_metadata, $branches) `}` + }]; +} \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp new file mode 100644 index 000000000..659f0ebf1 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -0,0 +1,35 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" + +using namespace mlir; +using namespace mlir::comm; + +// Parsing and printing for the split op branches. 
Modeled after the SCF +// switchcase parsing code () +static ParseResult +parseSplitBranch(OpAsmParser &p, mlir::ArrayAttr &branch_descriptors, + SmallVectorImpl> &branches) { + SmallVector branch_desc_values; + while (succeeded(p.parseOptionalKeyword("branch"))) { + SplitBranchDescriptorAttr branch_descriptor; + Region ®ion = *branches.emplace_back(std::make_unique()); + if (p.parseAttribute(branch_descriptor) || + p.parseRegion(region, /*arguments=*/{})) + return failure(); + branch_desc_values.push_back(branch_descriptor); + } + branch_descriptors = p.getBuilder().getArrayAttr(branch_desc_values); + return success(); +} + +/// Print the case regions and values. +static void printSplitBranch(OpAsmPrinter &p, Operation *op, + mlir::ArrayAttr cases, RegionRange caseRegions) { + // for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { + // p.printNewline(); + // p << "case " << value << ' '; + // p.printRegion(*region, /*printEntryBlockArgs=*/false); + // } +} + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp new file mode 100644 index 000000000..42fdccc1f --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp @@ -0,0 +1,7 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" + +using namespace mlir; +using namespace mlir::comm; + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/CommDialect.cpp deleted file mode 100644 index fad6ea992..000000000 --- a/src/enzyme_ad/jax/Dialects/CommDialect.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/CommDialect.h" -#include "src/enzyme_ad/jax/Dialects/CommDialect.cpp.inc" - -#define GET_OP_CLASSES -#include 
"src/enzyme_ad/jax/Dialects/CommOps.cpp.inc" - -using namespace mlir; -using namespace mlir::comm; - -void CommunicationDialect::initialize() { - addOperations(); // Register CommFoo operation - addOperations(); - addOperations(); -} - -// CommunicationDialect::CommunicationDialect(mlir::MLIRContext*){} \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.h b/src/enzyme_ad/jax/Dialects/CommDialect.h deleted file mode 100644 index 6c60c90f1..000000000 --- a/src/enzyme_ad/jax/Dialects/CommDialect.h +++ /dev/null @@ -1,12 +0,0 @@ -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/Support/TypeID.h" -#include "mlir/include/mlir/IR/DialectImplementation.h" -#include "src/enzyme_ad/jax/Dialects/CommDialect.h.inc" - -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Dialect.h" - -#define GET_OP_CLASSES -#include "src/enzyme_ad/jax/Dialects/CommOps.h.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/CommDialect.td b/src/enzyme_ad/jax/Dialects/CommDialect.td deleted file mode 100644 index d34b2516b..000000000 --- a/src/enzyme_ad/jax/Dialects/CommDialect.td +++ /dev/null @@ -1,58 +0,0 @@ -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinAttributes.td" -include "mlir/IR/DialectBase.td" -include "mlir/IR/Traits.td" - -def CommunicationDialect : Dialect { - let name = "comm"; - let summary = "A prototype dialect for various communication ops"; - let description = [{}]; - let cppNamespace = "::mlir::comm"; -} - -class CommOp traits = []> : Op; - -def CommFoo : CommOp<"foo"> { - let summary = "do-nothing test op"; - let arguments = (ins ); - let results = (outs ); - let assemblyFormat = [{ - attr-dict - }]; -} - -def CommJoin : CommOp<"join", traits = [Terminator]> { - let summary = "Denotes the end of a split block, similar to ret for a function"; - 
let arguments = (ins ); - let results = (outs ); - let assemblyFormat = [{ - attr-dict - }]; -} - -def CommSplitBranch : CommOp<"split_branch"> { - let summary = "One of the branches taken by different devices in a split op."; - let description = [{ - Inside of a split op, this branch will execute on the provided static list of devices. - The code inside of this branch will have access to any communication items declared in - its parent split block, or anything in an outside scope. - }]; - - // Takes in list of participating devices (TODO- currently list of int literals), code region as attributes - let arguments = (ins - ArrayAttr:$device_list - ); - let regions = (region - AnyRegion:$branch_code // TODO constraint on region? - ); - - let results = (outs ); // nothing? - let assemblyFormat = [{ - attr-dict `<` $device_list `>` $branch_code - }]; -} - -// def CommSplit : CommOp<"Split"> { -// let summary = "The highest level split node in the communication dialect." -// } \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp index d2d362a82..fd2652bd5 100644 --- a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp +++ b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialects/CommDialect.h" +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 820ad921b..e7137c071 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -47,7 +47,7 @@ #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "Dialects/CommDialect.h" +#include "Dialects/Comm/CommDialect.h" diff --git a/test/lit_tests/comm_unroll.mlir b/test/lit_tests/comm_unroll.mlir index d01277874..53726525b 100644 --- a/test/lit_tests/comm_unroll.mlir +++ 
b/test/lit_tests/comm_unroll.mlir @@ -4,10 +4,14 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { comm.foo - comm.split_branch <[1, 5]> { - ^start: - comm.foo - comm.join + comm.split { + branch [d1, d4] + ^start: + comm.foo + comm.join + branch [d2] + ^start: + comm.join } %start = stablehlo.constant dense<0> : tensor From 363c6c08581b783cb9b9ef31b2be14f33df36abc Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 20 Jan 2025 12:45:51 -0600 Subject: [PATCH 07/20] Custom parsers and proper attribute registration --- src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp | 9 --- .../jax/Dialects/Comm/CommDialect.cpp | 67 ++++++++++++++++--- src/enzyme_ad/jax/Dialects/Comm/CommDialect.h | 4 +- .../jax/Dialects/Comm/CommDialect.td | 3 +- src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 29 ++++++-- test/lit_tests/comm_unroll.mlir | 8 ++- 6 files changed, 93 insertions(+), 27 deletions(-) delete mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp deleted file mode 100644 index 9cb38080f..000000000 --- a/src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp +++ /dev/null @@ -1,9 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" -#include "llvm/ADT/TypeSwitch.h" - - -using namespace mlir; -using namespace mlir::comm; - -#define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp index 4add08b7d..8cca2123c 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -1,18 +1,69 @@ #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp.inc" +// Attr imports +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" // 
for dbgs + using namespace mlir; using namespace mlir::comm; - void CommunicationDialect::initialize() { - addOperations< - #define GET_OP_LIST - #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" - >(); addAttributes< - #define GET_ATTR_LIST - #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" - >(); +#define GET_ATTRDEF_LIST +#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" + >(); + // TODO types when we need them } + +/** + * Attribute implemenation included is needed for the addAttributes<> template + * to succeed, so for now it is easier to put attribute implementation stuff + * here. + */ + +/** + * In its current form this almost certainly did not need to be a custom parser- + * wrote this thinking I would add more to it. + */ +static ParseResult parseSplitBranchDescriptor(AsmParser &p, + ArrayRef devices) { + + llvm::dbgs() << "Starting branch descriptor custom parse\n"; + + llvm::SmallVector dev_set; + do { + // Do while since list shouldn't be empty + unsigned &id = dev_set.emplace_back(); + auto parse_id = p.parseInteger(id); + // Check for parse error + if (parse_id) { + return failure(); + } + llvm::dbgs() << "Parsed " << id << "\n"; + } while (succeeded(p.parseOptionalComma())); + llvm::dbgs() << "Ending branch descriptor custom parse\n"; + + devices = dev_set; + llvm::dbgs() << "Returning success\n"; + return success(); +} + +/// Print the case regions and values. 
+static void printSplitBranchDescriptor(AsmPrinter &p, + llvm::ArrayRef devices) { + // for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { + // p.printNewline(); + // p << "case " << value << ' '; + // p.printRegion(*region, /*printEntryBlockArgs=*/false); + // } +} + +#define GET_ATTRDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h index a2b03bdad..2a2e21e5d 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h @@ -7,8 +7,8 @@ #include "mlir/include/mlir/IR/DialectImplementation.h" #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h.inc" -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/Dialect.h" #define GET_TYPEDEF_CLASSES diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 39dfc9194..31d614591 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -40,7 +40,8 @@ class CommAttr traits = []> : AttrDef< def CommSplitBranchDescriptor : CommAttr<"SplitBranchDescriptor", "branch_descriptor"> { let parameters= (ins ArrayRefParameter<"unsigned">:$device_ids); - let assemblyFormat = [{ $device_ids }]; + let assemblyFormat = [{ `{` custom($device_ids) `}` }]; + // let assemblyFormat = [{ $device_ids }]; } /* diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp index 659f0ebf1..b9c5b53a6 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -1,4 +1,5 @@ #include 
"src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "llvm/Support/Debug.h" // for dbgs using namespace mlir; using namespace mlir::comm; @@ -9,13 +10,33 @@ static ParseResult parseSplitBranch(OpAsmParser &p, mlir::ArrayAttr &branch_descriptors, SmallVectorImpl> &branches) { SmallVector branch_desc_values; + llvm::dbgs() << "Looking for split branch\n"; while (succeeded(p.parseOptionalKeyword("branch"))) { - SplitBranchDescriptorAttr branch_descriptor; + llvm::dbgs() << "Found branch kw\n"; Region ®ion = *branches.emplace_back(std::make_unique()); - if (p.parseAttribute(branch_descriptor) || - p.parseRegion(region, /*arguments=*/{})) + + llvm::dbgs() << "Parsing branch description...\n"; + + SplitBranchDescriptorAttr branch_descriptor; + auto descriptor_parse_flag = + p.parseCustomAttributeWithFallback(branch_descriptor); + + llvm::dbgs() << "... done\n"; + if (!descriptor_parse_flag && + branch_descriptor.isa()) { + branch_desc_values.push_back( + branch_descriptor.cast()); + } else { + llvm::dbgs() << "Failed to parse branch descriptor\n"; + return failure(); + } + + llvm::dbgs() << "Parsing branch region...\n"; + auto parse_region = p.parseRegion(region, /*arguments=*/{}); + if (parse_region) { + llvm::dbgs() << "Failed to parse branch region\n"; return failure(); - branch_desc_values.push_back(branch_descriptor); + } } branch_descriptors = p.getBuilder().getArrayAttr(branch_desc_values); return success(); diff --git a/test/lit_tests/comm_unroll.mlir b/test/lit_tests/comm_unroll.mlir index 53726525b..b3ea97851 100644 --- a/test/lit_tests/comm_unroll.mlir +++ b/test/lit_tests/comm_unroll.mlir @@ -4,14 +4,16 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { comm.foo - comm.split { - branch [d1, d4] + comm.split {} { + branch {1, 4} { ^start: comm.foo comm.join - branch [d2] + } + branch {2} { ^start: comm.join + } } %start = stablehlo.constant dense<0> : tensor From d10b56e1d7e0cb0f3bdd64d8e55ba95908549ecd Mon Sep 17 00:00:00 2001 From: 
Egan Johnson Date: Mon, 20 Jan 2025 14:35:24 -0600 Subject: [PATCH 08/20] Printing --- .../jax/Dialects/Comm/CommDialect.cpp | 19 ++++++--------- .../jax/Dialects/Comm/CommDialect.td | 10 ++++++-- src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 23 ++++++++----------- test/lit_tests/comm_unroll.mlir | 4 ++-- 4 files changed, 26 insertions(+), 30 deletions(-) diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp index 8cca2123c..e89a39590 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -4,7 +4,6 @@ // Attr imports #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" // for dbgs using namespace mlir; using namespace mlir::comm; @@ -33,9 +32,7 @@ void CommunicationDialect::initialize() { * wrote this thinking I would add more to it. */ static ParseResult parseSplitBranchDescriptor(AsmParser &p, - ArrayRef devices) { - - llvm::dbgs() << "Starting branch descriptor custom parse\n"; + llvm::SmallVector &devices) { llvm::SmallVector dev_set; do { @@ -46,23 +43,21 @@ static ParseResult parseSplitBranchDescriptor(AsmParser &p, if (parse_id) { return failure(); } - llvm::dbgs() << "Parsed " << id << "\n"; } while (succeeded(p.parseOptionalComma())); - llvm::dbgs() << "Ending branch descriptor custom parse\n"; devices = dev_set; - llvm::dbgs() << "Returning success\n"; return success(); } /// Print the case regions and values. 
static void printSplitBranchDescriptor(AsmPrinter &p, llvm::ArrayRef devices) { - // for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { - // p.printNewline(); - // p << "case " << value << ' '; - // p.printRegion(*region, /*printEntryBlockArgs=*/false); - // } + for (int i = 0; i < devices.size(); i++) { + if (i > 0) { + p << ", "; + } + p << devices[i]; + } } #define GET_ATTRDEF_CLASSES diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 31d614591..489e89f0e 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -9,6 +9,7 @@ def CommunicationDialect : Dialect { let summary = "A prototype dialect for various communication ops"; let description = [{}]; let cppNamespace = "::mlir::comm"; + let useDefaultAttributePrinterParser = 1; } // Dialect inheritence shortcuts @@ -40,8 +41,13 @@ class CommAttr traits = []> : AttrDef< def CommSplitBranchDescriptor : CommAttr<"SplitBranchDescriptor", "branch_descriptor"> { let parameters= (ins ArrayRefParameter<"unsigned">:$device_ids); - let assemblyFormat = [{ `{` custom($device_ids) `}` }]; - // let assemblyFormat = [{ $device_ids }]; + let assemblyFormat = [{ `(` custom($device_ids) `)` }]; + let description = [{ + Attribute for describing a split branch. Currently holds only the device ids corresponding to the branch. 
+ + Syntax: + (deviceid[, deviceid]*) + }]; } /* diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp index b9c5b53a6..1d5d5a7c8 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -1,5 +1,4 @@ #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" -#include "llvm/Support/Debug.h" // for dbgs using namespace mlir; using namespace mlir::comm; @@ -10,31 +9,23 @@ static ParseResult parseSplitBranch(OpAsmParser &p, mlir::ArrayAttr &branch_descriptors, SmallVectorImpl> &branches) { SmallVector branch_desc_values; - llvm::dbgs() << "Looking for split branch\n"; while (succeeded(p.parseOptionalKeyword("branch"))) { - llvm::dbgs() << "Found branch kw\n"; Region ®ion = *branches.emplace_back(std::make_unique()); - llvm::dbgs() << "Parsing branch description...\n"; - SplitBranchDescriptorAttr branch_descriptor; auto descriptor_parse_flag = p.parseCustomAttributeWithFallback(branch_descriptor); - llvm::dbgs() << "... done\n"; if (!descriptor_parse_flag && branch_descriptor.isa()) { branch_desc_values.push_back( branch_descriptor.cast()); } else { - llvm::dbgs() << "Failed to parse branch descriptor\n"; return failure(); } - llvm::dbgs() << "Parsing branch region...\n"; auto parse_region = p.parseRegion(region, /*arguments=*/{}); if (parse_region) { - llvm::dbgs() << "Failed to parse branch region\n"; return failure(); } } @@ -45,11 +36,15 @@ parseSplitBranch(OpAsmParser &p, mlir::ArrayAttr &branch_descriptors, /// Print the case regions and values. 
static void printSplitBranch(OpAsmPrinter &p, Operation *op, mlir::ArrayAttr cases, RegionRange caseRegions) { - // for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { - // p.printNewline(); - // p << "case " << value << ' '; - // p.printRegion(*region, /*printEntryBlockArgs=*/false); - // } + p.increaseIndent(); + for (auto [descriptor, region] : llvm::zip(cases, caseRegions)) { + p.printNewline(); + p << "branch "; + descriptor.cast().print(p); + p << " "; + p.printRegion(*region, /*printEntryBlockArgs=*/false); + } + p.decreaseIndent(); } #define GET_OP_CLASSES diff --git a/test/lit_tests/comm_unroll.mlir b/test/lit_tests/comm_unroll.mlir index b3ea97851..123847948 100644 --- a/test/lit_tests/comm_unroll.mlir +++ b/test/lit_tests/comm_unroll.mlir @@ -5,12 +5,12 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { comm.foo comm.split {} { - branch {1, 4} { + branch (1, 4) { ^start: comm.foo comm.join } - branch {2} { + branch (2) { ^start: comm.join } From b2dded293921b96c132422dfa00b10005e43283e Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 20 Jan 2025 15:38:56 -0600 Subject: [PATCH 09/20] Added communicaton tokens --- .../jax/Dialects/Comm/CommDialect.cpp | 16 ++++++--- .../jax/Dialects/Comm/CommDialect.td | 35 ++++++++++++++----- src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp | 7 ---- .../basic_test.mlir} | 13 ++----- 4 files changed, 41 insertions(+), 30 deletions(-) delete mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp rename test/{lit_tests/comm_unroll.mlir => comm_tests/basic_test.mlir} (68%) diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp index e89a39590..25ef17723 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -1,7 +1,5 @@ #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp.inc" - -// Attr imports 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -9,6 +7,11 @@ using namespace mlir; using namespace mlir::comm; void CommunicationDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" + >(); + addAttributes< #define GET_ATTRDEF_LIST #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" @@ -31,8 +34,8 @@ void CommunicationDialect::initialize() { * In its current form this almost certainly did not need to be a custom parser- * wrote this thinking I would add more to it. */ -static ParseResult parseSplitBranchDescriptor(AsmParser &p, - llvm::SmallVector &devices) { +static ParseResult +parseSplitBranchDescriptor(AsmParser &p, llvm::SmallVector &devices) { llvm::SmallVector dev_set; do { @@ -61,4 +64,7 @@ static void printSplitBranchDescriptor(AsmPrinter &p, } #define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" \ No newline at end of file +#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 489e89f0e..e86b031d0 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -17,18 +17,23 @@ class CommOp traits = []> : Op traits = []> : AttrDef{ let mnemonic = mnemomic; } -// class CommType traits = []> : TypeDef { -// let mnemonic = type_mnemonic; -// } +class CommType traits = []> : TypeDef { + let mnemonic = type_mnemonic; +} /* * Dialect Types */ -// def DeviceIdType : CommType<"DeviceId", "device_id"> { -// let summary="Wrapper around int to specify a device from our device set"; -// let parameters=(ins "unsigned":$id); -// let assemblyFormat=[{`d` $id }]; -// } +def DeviceIdType : CommType<"DeviceId", "device_id"> { + let summary="Wrapper 
around int to specify a device from our device set"; + let parameters=(ins "unsigned":$id); + let assemblyFormat=[{`d` $id }]; +} +def MessageTokenType : CommType<"MessageToken", "token"> { + let summary = "Represents a consumable message token"; + // let parameters = (ins "Type":$msg_type); + +} /* * Dialect Attributes @@ -91,4 +96,18 @@ def CommSplit : CommOp<"split"> { let assemblyFormat = [{ attr-dict `{` custom($branch_metadata, $branches) `}` }]; +} + +// Message types. In the future we will likely want to have a common base class +def CommSimpleMessage: CommOp<"simple_msg"> { + let summary = "A simple single-usage, one-way message token"; + let arguments = (ins + TypeAttr:$msg_type + ); + let results = (outs + MessageTokenType:$token + ); + let assemblyFormat = [{ + attr-dict $msg_type + }]; } \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp deleted file mode 100644 index 42fdccc1f..000000000 --- a/src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" - -using namespace mlir; -using namespace mlir::comm; - -#define GET_TYPEDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/test/lit_tests/comm_unroll.mlir b/test/comm_tests/basic_test.mlir similarity index 68% rename from test/lit_tests/comm_unroll.mlir rename to test/comm_tests/basic_test.mlir index 123847948..743ddc878 100644 --- a/test/lit_tests/comm_unroll.mlir +++ b/test/comm_tests/basic_test.mlir @@ -4,6 +4,8 @@ module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { comm.foo + + %msg = comm.simple_msg tensor<2x2xf32> comm.split {} { branch (1, 4) { ^start: @@ -32,13 +34,4 @@ module { } return %w#0 : tensor<2x2xf32> } -} - -// CHECK: func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { -// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg0 : tensor<2x2xf32> -// 
CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<2x2xf32> -// CHECK-NEXT: %2 = stablehlo.add %1, %1 : tensor<2x2xf32> -// CHECK-NEXT: %3 = stablehlo.add %2, %2 : tensor<2x2xf32> -// CHECK-NEXT: %4 = stablehlo.add %3, %3 : tensor<2x2xf32> -// CHECK-NEXT: return %4 : tensor<2x2xf32> -// CHECK-NEXT: } +} \ No newline at end of file From f6573a412599d4b69cbf78b5205637d9f55b79cd Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 22 Jan 2025 11:02:46 -0600 Subject: [PATCH 10/20] Encode split branches, messages as ops in a single block --- .../jax/Dialects/Comm/CommDialect.cpp | 40 -------------- .../jax/Dialects/Comm/CommDialect.td | 53 +++++++++---------- src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 44 --------------- test/comm_tests/refactor_test.mlir | 46 ++++++++++++++++ 4 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 test/comm_tests/refactor_test.mlir diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp index 25ef17723..a560f05a1 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -20,48 +20,8 @@ void CommunicationDialect::initialize() { #define GET_OP_LIST #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" >(); - - // TODO types when we need them } -/** - * Attribute implemenation included is needed for the addAttributes<> template - * to succeed, so for now it is easier to put attribute implementation stuff - * here. - */ - -/** - * In its current form this almost certainly did not need to be a custom parser- - * wrote this thinking I would add more to it. 
- */ -static ParseResult -parseSplitBranchDescriptor(AsmParser &p, llvm::SmallVector &devices) { - - llvm::SmallVector dev_set; - do { - // Do while since list shouldn't be empty - unsigned &id = dev_set.emplace_back(); - auto parse_id = p.parseInteger(id); - // Check for parse error - if (parse_id) { - return failure(); - } - } while (succeeded(p.parseOptionalComma())); - - devices = dev_set; - return success(); -} - -/// Print the case regions and values. -static void printSplitBranchDescriptor(AsmPrinter &p, - llvm::ArrayRef devices) { - for (int i = 0; i < devices.size(); i++) { - if (i > 0) { - p << ", "; - } - p << devices[i]; - } -} #define GET_ATTRDEF_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index e86b031d0..9dea6b759 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -9,7 +9,6 @@ def CommunicationDialect : Dialect { let summary = "A prototype dialect for various communication ops"; let description = [{}]; let cppNamespace = "::mlir::comm"; - let useDefaultAttributePrinterParser = 1; } // Dialect inheritence shortcuts @@ -35,26 +34,6 @@ def MessageTokenType : CommType<"MessageToken", "token"> { } -/* -* Dialect Attributes -*/ - -// def CommDeviceIdAttr : CommAttr<"DeviceIdAttr", "device_id_attr">{ -// let parameters=(ins DeviceIdType:$id); -// let assemblyFormat = [{ $id }]; -// } - -def CommSplitBranchDescriptor : CommAttr<"SplitBranchDescriptor", "branch_descriptor"> { - let parameters= (ins ArrayRefParameter<"unsigned">:$device_ids); - let assemblyFormat = [{ `(` custom($device_ids) `)` }]; - let description = [{ - Attribute for describing a split branch. Currently holds only the device ids corresponding to the branch. 
- - Syntax: - (deviceid[, deviceid]*) - }]; -} - /* * Dialect Ops */ @@ -79,22 +58,38 @@ def CommJoin : CommOp<"join", traits = [Terminator]> { }]; } -def CommSplit : CommOp<"split"> { +def CommSplit : CommOp<"split", traits = [SingleBlock, NoTerminator]> { let summary = "The highest level split node in the communication dialect."; let description = [{ Takes in a definition of communication items and a list of split branches for devices to take. + Encoded as a single-block no-terminator region that consists only of branches and communcation token declarations. + Example syntax: + comm.split { + %1 = comm.simple_msg msg_type + comm.branch [1, 4] { + // ... comm branch region + } + comm.branch [2] { + // ... comm branch region + } + } }]; - let arguments = (ins - ArrayAttr:$branch_metadata // array of attributes of type CommSplitBranchDescriptor. Parallel array with $branches. - ); - let regions = (region - VariadicRegion:$branches // regions for each branch. Parallel array with $branch_metadata - ); + let arguments = (ins ); // no inputs yet, encoded in the region + let regions = (region SizedRegion<1>:$declarations); let results = (outs ); let assemblyFormat = [{ - attr-dict `{` custom($branch_metadata, $branches) `}` + $declarations attr-dict + }]; +} + +def CommBranch : CommOp<"branch"> { + let summary = "Represents one branch that can be taken by a split node"; + let arguments = (ins DenseI32ArrayAttr:$device_ids); + let regions = (region AnyRegion:$region); + let assemblyFormat = [{ + attr-dict $device_ids $region }]; } diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp index 1d5d5a7c8..81eaa3954 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -3,49 +3,5 @@ using namespace mlir; using namespace mlir::comm; -// Parsing and printing for the split op branches. 
Modeled after the SCF -// switchcase parsing code () -static ParseResult -parseSplitBranch(OpAsmParser &p, mlir::ArrayAttr &branch_descriptors, - SmallVectorImpl> &branches) { - SmallVector branch_desc_values; - while (succeeded(p.parseOptionalKeyword("branch"))) { - Region ®ion = *branches.emplace_back(std::make_unique()); - - SplitBranchDescriptorAttr branch_descriptor; - auto descriptor_parse_flag = - p.parseCustomAttributeWithFallback(branch_descriptor); - - if (!descriptor_parse_flag && - branch_descriptor.isa()) { - branch_desc_values.push_back( - branch_descriptor.cast()); - } else { - return failure(); - } - - auto parse_region = p.parseRegion(region, /*arguments=*/{}); - if (parse_region) { - return failure(); - } - } - branch_descriptors = p.getBuilder().getArrayAttr(branch_desc_values); - return success(); -} - -/// Print the case regions and values. -static void printSplitBranch(OpAsmPrinter &p, Operation *op, - mlir::ArrayAttr cases, RegionRange caseRegions) { - p.increaseIndent(); - for (auto [descriptor, region] : llvm::zip(cases, caseRegions)) { - p.printNewline(); - p << "branch "; - descriptor.cast().print(p); - p << " "; - p.printRegion(*region, /*printEntryBlockArgs=*/false); - } - p.decreaseIndent(); -} - #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" \ No newline at end of file diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir new file mode 100644 index 000000000..d209e638c --- /dev/null +++ b/test/comm_tests/refactor_test.mlir @@ -0,0 +1,46 @@ +// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s + +module { + + func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { + comm.foo + + comm.split { + %msg = comm.simple_msg tensor<2x2xf32> + comm.branch [1, 4] { + ^start: + comm.foo + comm.split { + comm.branch [1] { + comm.join + } + comm.branch [4] { + comm.foo + comm.join + } + } + comm.join + } + comm.branch [2] { + ^start: + comm.join + } + } 
+ %start = stablehlo.constant dense<0> : tensor + + %lim = stablehlo.constant dense<5> : tensor + + %step = stablehlo.constant dense<1> : tensor + + %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + return %w#0 : tensor<2x2xf32> + } +} \ No newline at end of file From 7daa1b6599b832e7ae84accf5e7cb3114293ec96 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 22 Jan 2025 16:11:54 -0600 Subject: [PATCH 11/20] Added verifiers and getters --- src/enzyme_ad/jax/Dialects/Comm/CommDialect.h | 16 +++++++ .../jax/Dialects/Comm/CommDialect.td | 33 ++++++++++++- src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 26 ++++++++++ test/comm_tests/branch_duplicated_device.mlir | 48 +++++++++++++++++++ test/comm_tests/branch_out_of_split.mlir | 47 ++++++++++++++++++ test/comm_tests/extra_op_in_split.mlir | 47 ++++++++++++++++++ test/comm_tests/refactor_test.mlir | 3 +- 7 files changed, 217 insertions(+), 3 deletions(-) create mode 100644 test/comm_tests/branch_duplicated_device.mlir create mode 100644 test/comm_tests/branch_out_of_split.mlir create mode 100644 test/comm_tests/extra_op_in_split.mlir diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h index 2a2e21e5d..f15c4f263 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h @@ -11,6 +11,14 @@ #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/Dialect.h" +namespace mlir::comm { +template +class SplitMemberOp : public OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op); +}; +} // namespace mlir::comm + #define 
GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.h.inc" @@ -20,4 +28,12 @@ #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.h.inc" +template +mlir::LogicalResult mlir::comm::SplitMemberOp::verifyTrait(Operation *op) { + if (!isa(op->getParentOp())) { + return op->emitOpError("must be located as immediate child of split op"); + } + return success(); +} + #endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 9dea6b759..3fd80abed 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -34,6 +34,23 @@ def MessageTokenType : CommType<"MessageToken", "token"> { } +/* +* Dialect traits +*/ +def CommSplitMemberOpTrait : NativeOpTrait<"SplitMemberOp", + /*traits=*/[], + /*extraOpDeclaration = */[{ + mlir::comm::CommSplit getParentSplit(); + }], + /*extraOpDefinition = */[{ + mlir::comm::CommSplit $cppClass::getParentSplit(){ + // Verifier checks that this is indeed of the correct type + return dyn_cast(getOperation()->getParentOp()); + } + }] +>{ + let cppNamespace = "::mlir::comm"; +} /* * Dialect Ops */ @@ -82,9 +99,21 @@ def CommSplit : CommOp<"split", traits = [SingleBlock, NoTerminator]> { let assemblyFormat = [{ $declarations attr-dict }]; + + let hasVerifier = 1; + + // Add some convenience getters to hide the mess around having a declarations region + let extraClassDeclaration = [{ + auto getMessages() { + return getDeclarations().getOps<::mlir::comm::CommSimpleMessage>(); + } + auto getBranches() { + return getDeclarations().getOps<::mlir::comm::CommBranch>(); + } + }]; } -def CommBranch : CommOp<"branch"> { +def CommBranch : CommOp<"branch", traits = [CommSplitMemberOpTrait]> { let summary = "Represents one branch that can be taken by a split node"; let arguments = (ins DenseI32ArrayAttr:$device_ids); let regions = (region AnyRegion:$region); @@ -94,7 +123,7 @@ def 
CommBranch : CommOp<"branch"> { } // Message types. In the future we will likely want to have a common base class -def CommSimpleMessage: CommOp<"simple_msg"> { +def CommSimpleMessage: CommOp<"simple_msg", traits = [CommSplitMemberOpTrait]> { let summary = "A simple single-usage, one-way message token"; let arguments = (ins TypeAttr:$msg_type diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp index 81eaa3954..e1df33ea1 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -1,7 +1,33 @@ +#include "mlir/Support/LogicalResult.h" #include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "llvm/ADT/DenseSet.h" + + using namespace mlir; using namespace mlir::comm; + +LogicalResult CommSplit::verify() { + for(Operation &op : getDeclarations().getOps()){ + // Check that all ops are allowable as members + if (!op.hasTrait()){ + return op.emitOpError("not allowed as immediate split op member"); + } + + // check that all branches have disjoint device sets + DenseSet used_devices; + for(CommBranch branch : getBranches()){ + for(unsigned device : branch.getDeviceIds()){ + if (used_devices.contains(device)){ + return branch.emitError("uses device already accounted for in same split"); + } + used_devices.insert(device); + } + } + } + return success(); +} + #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" \ No newline at end of file diff --git a/test/comm_tests/branch_duplicated_device.mlir b/test/comm_tests/branch_duplicated_device.mlir new file mode 100644 index 000000000..b9d17de3d --- /dev/null +++ b/test/comm_tests/branch_duplicated_device.mlir @@ -0,0 +1,48 @@ +// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s + +module { + + func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { + comm.foo + + comm.split { + %msg = comm.simple_msg tensor<2x2xf32> + comm.branch [1, 4] { + ^start: + comm.foo + 
comm.split { + comm.branch [1] { + comm.join + } + comm.branch [4] { + comm.foo + comm.join + } + } + comm.join + } + comm.branch [2, 4] { + ^start: + comm.join + } + } + + + %start = stablehlo.constant dense<0> : tensor + + %lim = stablehlo.constant dense<5> : tensor + + %step = stablehlo.constant dense<1> : tensor + + %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + return %w#0 : tensor<2x2xf32> + } +} \ No newline at end of file diff --git a/test/comm_tests/branch_out_of_split.mlir b/test/comm_tests/branch_out_of_split.mlir new file mode 100644 index 000000000..38ebf5ef3 --- /dev/null +++ b/test/comm_tests/branch_out_of_split.mlir @@ -0,0 +1,47 @@ +// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s + +module { + + func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { + comm.foo + + comm.split { + %msg = comm.simple_msg tensor<2x2xf32> + comm.branch [1, 4] { + ^start: + comm.foo + comm.split { + comm.branch [1] { + comm.join + } + comm.branch [4] { + comm.foo + comm.join + } + } + comm.join + } + } + comm.branch [2] { + ^start: + comm.join + } + + %start = stablehlo.constant dense<0> : tensor + + %lim = stablehlo.constant dense<5> : tensor + + %step = stablehlo.constant dense<1> : tensor + + %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + return 
%w#0 : tensor<2x2xf32> + } +} \ No newline at end of file diff --git a/test/comm_tests/extra_op_in_split.mlir b/test/comm_tests/extra_op_in_split.mlir new file mode 100644 index 000000000..c0e0dd6ab --- /dev/null +++ b/test/comm_tests/extra_op_in_split.mlir @@ -0,0 +1,47 @@ +// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s + +module { + + func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { + comm.foo + + comm.split { + %msg = comm.simple_msg tensor<2x2xf32> + comm.branch [1, 4] { + ^start: + comm.foo + comm.split { + comm.branch [1] { + comm.join + } + comm.branch [4] { + comm.foo + comm.join + } + comm.foo + } + comm.join + } + comm.branch [2] { + ^start: + comm.join + } + } + %start = stablehlo.constant dense<0> : tensor + + %lim = stablehlo.constant dense<5> : tensor + + %step = stablehlo.constant dense<1> : tensor + + %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + return %w#0 : tensor<2x2xf32> + } +} \ No newline at end of file diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir index d209e638c..4f58c2822 100644 --- a/test/comm_tests/refactor_test.mlir +++ b/test/comm_tests/refactor_test.mlir @@ -10,7 +10,8 @@ module { comm.branch [1, 4] { ^start: comm.foo - comm.split { + comm.split { + %msg2 = comm.simple_msg f32 comm.branch [1] { comm.join } From 5fdfc0fea3ad3f4e9fccbc449ee1f47fefced97c Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 22 Jan 2025 18:02:11 -0600 Subject: [PATCH 12/20] Explode split block branches pass --- .../jax/Passes/CommExplodeSplits.cpp | 119 ++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 5 + 
src/enzyme_ad/jax/Passes/Passes.td | 7 ++ test/comm_tests/refactor_test.mlir | 2 + 4 files changed, 133 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp diff --git a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp new file mode 100644 index 000000000..d7b1f516d --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp @@ -0,0 +1,119 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::comm; +using namespace enzyme; +using namespace mlir::enzyme; // one of the upstream includes we need is wrapped + // in this namespace + +namespace { +struct CommExplodeSplits : public CommExplodeSplitsBase { + + /** + * After a branch is split and contains only one device, look through it for + * split nodes and inline them. + */ + static void inlineSplits(CommBranch branch) { + assert(branch.getDeviceIds().size() == 1 && + "Shouldn't inline on branch with multiple devices"); + + for (CommSplit subsplit : llvm::to_vector(branch.getOps())) { + // locate the branch corresponding to the current one + for (CommBranch subbranch : subsplit.getBranches()) { + assert(subbranch.getDeviceIds().size() == 1 && + "Sub splits should already have been exploded"); + if (!(subbranch.getDeviceIds().front() == + branch.getDeviceIds().front())) + continue; + + // We've found our branch. Now we want to copy its code up in place of + // the split branch. If we only have one basic block this is easy. + if (subbranch.getRegion().hasOneBlock()) { + // Can simply copy all instructions before split and then erase it + Block &block = subbranch.getRegion().front(); + while (!block.empty()) { + // TODO what about other terminators? 
+ auto &op = block.front(); + if (isa(op)) { + op.erase(); + } else { + op.moveBefore(subsplit); + } + } + subsplit.erase(); + + } else { + assert(0 && "TODO inline multiple blocks"); + // TODO split the basic block, add a jump to entry, and link joins to next statement. + } + + break; + } + } + } + + /** + * Functor to perform the explode split, i.e. create a distinct branch for + * each device under a single top-level split. Must be called as a post-order + * traversal so that all sub-splits are already exploded. + */ + static void explodeSplit(CommSplit split) { + + for (CommBranch branch : llvm::to_vector(split.getBranches())) { + llvm::dbgs() << "Running on branch " << branch << "\n"; + // Look for any subsplits and move their message declarations up + branch.walk([&](CommSplit subsplit) { + for (auto msg : llvm::to_vector(subsplit.getMessages())) { + msg->moveBefore(branch); + } + }); + + // Copy the branch for each additional device. For the last one we can + // just mutate instead of cloning. + for (int i = 0; i < branch.getDeviceIds().size() - 1; i++) { + llvm::dbgs() << "Creating deep clone of branch\n"; + CommBranch cloned = branch.clone(); + + llvm::dbgs() << "Setting branch ids\n"; + cloned.setDeviceIds({branch.getDeviceIds()[i]}); + llvm::dbgs() << "Moving branch into split block at loc " + << branch.getLoc() << "\n"; + // can't use insertBefore/After immediately since it segfaults trying to + // unlink. 
+ branch->getBlock()->push_back(cloned); + cloned->moveBefore(branch); + inlineSplits(cloned); + } + + branch.setDeviceIds({branch.getDeviceIds().back()}); + inlineSplits(branch); + + llvm::dbgs() << "Done with current branch\n"; + } + + llvm::dbgs() << "Done with all branches\n"; + } + + void runOnOperation() override { + llvm::dbgs() << "Running pass on op " << getOperation() << "\n"; + getOperation()->walk(explodeSplit); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace comm { +std::unique_ptr createCommExplodeSplitsPass() { + llvm::dbgs() << "creating explode pass\n"; + return std::make_unique(); +} +} // namespace comm +} // namespace mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index ec6a23b5f..977d7aaac 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -12,6 +12,9 @@ #include "mlir/Pass/Pass.h" #include +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" + + namespace mlir { class PatternRewriter; class RewritePatternSet; @@ -24,6 +27,7 @@ std::unique_ptr createPrintPass(); } // namespace enzyme namespace comm { std::unique_ptr createCommRemoveFooPass(); +std::unique_ptr createCommExplodeSplitsPass(); } } // namespace mlir @@ -89,5 +93,6 @@ static void regsiterenzymeXLAPasses() { registerEnzymeHLOOptPass(); registerEnzymeHLOUnrollPass(); registerCommRemoveFooPass(); + registerCommExplodeSplitsPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index ce1346323..961d82d57 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -20,6 +20,13 @@ def CommRemoveFoo : Pass<"remove-comm-foo"> { let constructor = "mlir::comm::createCommRemoveFooPass()"; } +def CommExplodeSplits : Pass<"explode-comm-splits"> { + let summary = "Converts a (nested) split into a non-nested split with one branch per device"; + let dependentDialects = [ + 
"comm::CommunicationDialect" + ]; + let constructor = "mlir::comm::createCommExplodeSplitsPass()"; +} def ArithRaisingPass : Pass<"arith-raise"> { let summary = "Raise Arith to mhlo"; diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir index 4f58c2822..eeb20847a 100644 --- a/test/comm_tests/refactor_test.mlir +++ b/test/comm_tests/refactor_test.mlir @@ -13,6 +13,8 @@ module { comm.split { %msg2 = comm.simple_msg f32 comm.branch [1] { + %lim = stablehlo.constant dense<5> : tensor + %step = stablehlo.constant dense<1> : tensor comm.join } comm.branch [4] { From 4359675d29d46f9ad28b46cc5e4619590d438531 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 29 Jan 2025 14:32:52 -0600 Subject: [PATCH 13/20] Extended dialect, added branch simplification pass --- src/enzyme_ad/jax/BUILD | 14 ++- .../jax/Dialects/Comm/CommDialect.cpp | 1 - src/enzyme_ad/jax/Dialects/Comm/CommDialect.h | 2 + .../jax/Dialects/Comm/CommDialect.td | 67 +++++++++-- .../jax/Dialects/Comm/CommInterfaces.cpp | 6 + src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp | 47 +++++++- .../jax/Passes/CommExplodeSplits.cpp | 108 +++++++++++++++--- test/comm_tests/refactor_test.mlir | 46 ++++---- 8 files changed, 241 insertions(+), 50 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index a371d66b4..5404346c2 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -117,17 +117,23 @@ gentbl_cc_library( ["-gen-op-defs", "-dialect=comm"], "Dialects/Comm/CommOps.cpp.inc", ),( - ["-gen-attrdef-decls", "--attrdefs-dialect=comm"], + ["-gen-attrdef-decls", "-attrdefs-dialect=comm"], "Dialects/Comm/CommAttrs.h.inc", ),( - ["-gen-attrdef-defs", "--attrdefs-dialect=comm"], + ["-gen-attrdef-defs", "-attrdefs-dialect=comm"], "Dialects/Comm/CommAttrs.cpp.inc", ),( - ["-gen-typedef-decls", "--typedefs-dialect=comm"], + ["-gen-typedef-decls", "-typedefs-dialect=comm"], 
"Dialects/Comm/CommTypes.h.inc", ), ( - ["-gen-typedef-defs", "--typedefs-dialect=comm"], + ["-gen-typedef-defs", "-typedefs-dialect=comm"], "Dialects/Comm/CommTypes.cpp.inc", + ),( + ["-gen-op-interface-decls"], + "Dialects/Comm/CommInterfaces.h.inc", + ), ( + ["-gen-op-interface-defs"], + "Dialects/Comm/CommInterfaces.cpp.inc", ), ], td_file = "Dialects/Comm/CommDialect.td", diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp index a560f05a1..ac7886a3c 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp @@ -22,7 +22,6 @@ void CommunicationDialect::initialize() { >(); } - #define GET_ATTRDEF_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h index f15c4f263..33d47ae19 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h @@ -22,6 +22,8 @@ class SplitMemberOp : public OpTrait::TraitBase { #define GET_TYPEDEF_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.h.inc" +#include "src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.h.inc" + #define GET_ATTRDEF_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.h.inc" diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 3fd80abed..1d72b23a5 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -9,6 +9,7 @@ def CommunicationDialect : Dialect { let summary = "A prototype dialect for various communication ops"; let description = [{}]; let cppNamespace = "::mlir::comm"; + let useDefaultTypePrinterParser = 1; } // Dialect inheritence shortcuts @@ -30,12 +31,11 @@ def DeviceIdType : CommType<"DeviceId", "device_id"> { } def MessageTokenType : CommType<"MessageToken", "token"> { let summary = 
"Represents a consumable message token"; - // let parameters = (ins "Type":$msg_type); - + let mnemonic = "msg_token"; } /* -* Dialect traits +* Dialect traits and interfaces */ def CommSplitMemberOpTrait : NativeOpTrait<"SplitMemberOp", /*traits=*/[], @@ -51,6 +51,19 @@ def CommSplitMemberOpTrait : NativeOpTrait<"SplitMemberOp", >{ let cppNamespace = "::mlir::comm"; } + +def CommMessage : OpInterface<"CommMessage"> { + let cppNamespace = "::mlir::comm"; + let methods = [ + InterfaceMethod<[{ + Returns what type this message takes as inputs + }], "mlir::Type", "getInputType">, + InterfaceMethod<[{ + Returns what type will result from recieving this message + }], "mlir::Type", "getOutputType"> + ]; +} + /* * Dialect Ops */ @@ -105,7 +118,7 @@ def CommSplit : CommOp<"split", traits = [SingleBlock, NoTerminator]> { // Add some convenience getters to hide the mess around having a declarations region let extraClassDeclaration = [{ auto getMessages() { - return getDeclarations().getOps<::mlir::comm::CommSimpleMessage>(); + return getDeclarations().getOps<::mlir::comm::CommMessage>(); } auto getBranches() { return getDeclarations().getOps<::mlir::comm::CommBranch>(); @@ -122,16 +135,56 @@ def CommBranch : CommOp<"branch", traits = [CommSplitMemberOpTrait]> { }]; } +def CommSend: CommOp<"send"> { + let summary = "An op to fulfill (part of) a messages input."; + let arguments = (ins MessageTokenType:$token, AnyType:$data); + let results = (outs ); + let assemblyFormat = [{ + attr-dict $token $data `:` type($data) + }]; + let extraClassDeclaration = [{ + CommSimpleMessage getMessage(); + }]; + let hasVerifier = 1; +} + +def CommRecv: CommOp<"recv"> { + let summary = "An op that blocks and returns the messages output"; + let arguments = (ins MessageTokenType:$token); + let results = (outs AnyType:$data); + let assemblyFormat = [{ + attr-dict $token `:` type($data) + }]; +} + +/* +* Different types of message ops +*/ +// Base class for messages +class CommMessageBase 
extra_traits = []>: CommOp, CommSplitMemberOpTrait]>; + // Message types. In the future we will likely want to have a common base class -def CommSimpleMessage: CommOp<"simple_msg", traits = [CommSplitMemberOpTrait]> { +def CommSimpleMessage: CommMessageBase<"simple_msg"> { let summary = "A simple single-usage, one-way message token"; let arguments = (ins - TypeAttr:$msg_type + TypeAttr:$data_type ); let results = (outs MessageTokenType:$token ); let assemblyFormat = [{ - attr-dict $msg_type + attr-dict $data_type }]; +} + +def CommMultiplexMessage: CommMessageBase<"multiplex_msg"> { + let summary = "A phi node-like message that allows the compiler to choose from any of the input messages"; + let arguments = (ins TypeAttr:$data_type, Variadic:$in_tokens); + let results = (outs + MessageTokenType:$out_token + ); + let assemblyFormat = [{ + attr-dict $data_type $in_tokens + }]; + let hasVerifier = 1; } \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp new file mode 100644 index 000000000..b6fe27982 --- /dev/null +++ b/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp @@ -0,0 +1,6 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" + +using namespace mlir; +using namespace mlir::comm; + +#include "src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp index e1df33ea1..cc1e36c8b 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp @@ -7,7 +7,7 @@ using namespace mlir; using namespace mlir::comm; - +// Split LogicalResult CommSplit::verify() { for(Operation &op : getDeclarations().getOps()){ // Check that all ops are allowable as members @@ -28,6 +28,51 @@ LogicalResult CommSplit::verify() { } return success(); } +// CommSend +LogicalResult CommSend::verify(){ + auto op = 
getToken().getDefiningOp(); + if(!isa(op)) return emitError("can only send to tokens from simple messages"); + return success(); +} + +CommSimpleMessage CommSend::getMessage(){ + return dyn_cast(getToken().getDefiningOp()); +} + +// CommSimpleMessage +mlir::Type CommSimpleMessage::getInputType() { + return getDataType(); +} + +mlir::Type CommSimpleMessage::getOutputType() { + return getDataType(); +} + +// CommMultiplexMessage +LogicalResult CommMultiplexMessage::verify() { + for (mlir::Value input_token : getInTokens()) { + auto input_op = input_token.getDefiningOp(); + if(CommMessage input_msg = dyn_cast(input_op)){ + // check that the data types of the input message and this message match + if(input_msg.getOutputType() != getDataType()){ + return emitError("includes message with return type different than declared"); + } + } else { + // TODO write verification to ensure all tokens are defined by messages only + return input_op->emitError("message tokens should only be defined by message declarations"); + } + } + return success(); +} + +mlir::Type CommMultiplexMessage::getInputType() { + // cannot send to a multiplex message! 
+ return NoneType(); +} + +mlir::Type CommMultiplexMessage::getOutputType() { + return getDataType(); +} #define GET_OP_CLASSES #include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp index d7b1f516d..868a71aca 100644 --- a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp +++ b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp @@ -2,8 +2,12 @@ #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" +#include "mlir/IR/Builders.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/DenseMap.h" + +#include #define DEBUG_TYPE "enzyme" @@ -16,6 +20,40 @@ using namespace mlir::enzyme; // one of the upstream includes we need is wrapped namespace { struct CommExplodeSplits : public CommExplodeSplitsBase { + /** + * Creates a multiplexmessage wrapping this one and sets all recieves from this + * message to the multiplex. 
+ * + */ + static CommMultiplexMessage createMultiplex(OpBuilder &builder, CommMessage msg) { + OpBuilder::InsertionGuard insert_guard(builder); + + // Find all recieve points + llvm::SmallVector receiving_uses; + for(auto &use : msg->getUses()) { + Operation* user = use.getOwner(); + + if (!isa(user)){ + // Complex messages count as receivers too, not just recv statements + receiving_uses.push_back(&use); + } + } + + // Create multiplex just after this one + builder.setInsertionPointAfter(msg); + + Type out_tok_type = MessageTokenType::get(msg.getContext()); + ValueRange in_tokens = (ValueRange({msg->getOpResult(0)})); + TypeAttr out_data_type = TypeAttr::get(msg.getOutputType()); + CommMultiplexMessage plex = builder.create(msg.getLoc(), out_tok_type, out_data_type, in_tokens); + + for(auto use : receiving_uses){ + use->assign(plex.getOutToken()); + } + return plex; + } + + /** * After a branch is split and contains only one device, look through it for * split nodes and inline them. @@ -67,7 +105,8 @@ struct CommExplodeSplits : public CommExplodeSplitsBase { static void explodeSplit(CommSplit split) { for (CommBranch branch : llvm::to_vector(split.getBranches())) { - llvm::dbgs() << "Running on branch " << branch << "\n"; + int n_clones = branch.getDeviceIds().size(); + // Look for any subsplits and move their message declarations up branch.walk([&](CommSplit subsplit) { for (auto msg : llvm::to_vector(subsplit.getMessages())) { @@ -75,34 +114,74 @@ struct CommExplodeSplits : public CommExplodeSplitsBase { } }); + // Don't care about branches with only one device + if(n_clones == 1) { + inlineSplits(branch); + continue; + } + + // Look for any send statements that appear in this branch in a multi-device + // context. When we clone the branches, we will want to transform their simple messages + // into multiplex messages. Since sub branches have already been exploded they can be + // excluded from consideration. 
+ // mapping of tokens to their multiplex op + OpBuilder builder(branch->getContext()); + llvm::DenseMap, CommMultiplexMessage> tok_to_plex; + split->walk([&](Operation *op){ + if(isa(op) && op != split) { + // sub branches exploded already + return WalkResult::skip(); + } + if(CommSend send = dyn_cast(op)){ + if(!tok_to_plex.contains(send.getToken())){ + CommMultiplexMessage multiplex = createMultiplex(builder, send.getMessage()); + tok_to_plex.insert(std::make_pair(send.getToken(), multiplex)); + } + } + return WalkResult::advance(); + }); + // Copy the branch for each additional device. For the last one we can // just mutate instead of cloning. - for (int i = 0; i < branch.getDeviceIds().size() - 1; i++) { - llvm::dbgs() << "Creating deep clone of branch\n"; + for (int i = 0; i < n_clones - 1; i++) { CommBranch cloned = branch.clone(); - llvm::dbgs() << "Setting branch ids\n"; cloned.setDeviceIds({branch.getDeviceIds()[i]}); - llvm::dbgs() << "Moving branch into split block at loc " - << branch.getLoc() << "\n"; - // can't use insertBefore/After immediately since it segfaults trying to - // unlink. 
+ + // can't use insertBefore/After immediately since it segfaults trying to unlink- need to add to a block first branch->getBlock()->push_back(cloned); cloned->moveBefore(branch); inlineSplits(cloned); + + // Walk over all sends and if they have been mapped to a multiplex replace them + llvm::DenseMap, TypedValue> token_replacements; + cloned.walk([&](CommSend send) { + auto orig_token = send.getToken(); + if(token_replacements.contains(orig_token)){ + // if we've already cloned the message for this branch use the clone + send.getTokenMutable().assign(token_replacements[orig_token]); + } else if (tok_to_plex.contains(orig_token)){ + // if this token needs to be multiplexed, create a clone of its message + // and add to the multiplex + + auto orig_msg = orig_token.getDefiningOp(); + auto cloned_msg = llvm::cast(orig_msg->clone()); + orig_msg->getBlock()->push_back(cloned_msg); + cloned_msg->moveAfter(orig_token.getDefiningOp()); + auto new_token = cloned_msg.getToken(); + tok_to_plex[orig_token].getInTokensMutable().append({new_token}); + send.getTokenMutable().assign(new_token); + token_replacements[orig_token] = new_token; + } + }); } branch.setDeviceIds({branch.getDeviceIds().back()}); inlineSplits(branch); - - llvm::dbgs() << "Done with current branch\n"; } - - llvm::dbgs() << "Done with all branches\n"; } void runOnOperation() override { - llvm::dbgs() << "Running pass on op " << getOperation() << "\n"; getOperation()->walk(explodeSplit); } }; @@ -112,7 +191,6 @@ struct CommExplodeSplits : public CommExplodeSplitsBase { namespace mlir { namespace comm { std::unique_ptr createCommExplodeSplitsPass() { - llvm::dbgs() << "creating explode pass\n"; return std::make_unique(); } } // namespace comm diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir index eeb20847a..35b5fc279 100644 --- a/test/comm_tests/refactor_test.mlir +++ b/test/comm_tests/refactor_test.mlir @@ -2,48 +2,50 @@ module { - func.func @main(%a : tensor<2x2xf32>) -> 
tensor<2x2xf32> { + func.func @main(%a : tensor<2x2xf32>) -> tensor { comm.foo comm.split { %msg = comm.simple_msg tensor<2x2xf32> + %msg3 = comm.simple_msg tensor<2x2xf32> comm.branch [1, 4] { - ^start: comm.foo comm.split { - %msg2 = comm.simple_msg f32 + %msg2 = comm.simple_msg tensor comm.branch [1] { - %lim = stablehlo.constant dense<5> : tensor %step = stablehlo.constant dense<1> : tensor + comm.send %msg2 %step : tensor comm.join } comm.branch [4] { - comm.foo + %start = stablehlo.constant dense<0> : tensor + %lim = stablehlo.constant dense<5> : tensor + %step = comm.recv %msg2 : tensor + %w, %z = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor + cond { + %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %9737 : tensor + } do { + %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> + %ni = stablehlo.add %iterArg_0, %step : tensor + stablehlo.return %next, %ni : tensor<2x2xf32>, tensor + } + comm.send %msg %w : tensor<2x2xf32> comm.join } - } + } + %tens = comm.recv %msg : tensor<2x2xf32> + // .. 
do something with tensor and then rebroadcast it + comm.send %msg3 %tens : tensor<2x2xf32> comm.join } comm.branch [2] { ^start: + %tens = comm.recv %msg3 : tensor<2x2xf32> comm.join } } - %start = stablehlo.constant dense<0> : tensor - - %lim = stablehlo.constant dense<5> : tensor - - %step = stablehlo.constant dense<1> : tensor - - %w:2 = stablehlo.while(%iterArg = %a, %iterArg_0 = %start) : tensor<2x2xf32>, tensor - cond { - %9737 = stablehlo.compare LT, %iterArg_0, %lim, SIGNED : (tensor, tensor) -> tensor - stablehlo.return %9737 : tensor - } do { - %next = stablehlo.add %iterArg, %iterArg : tensor<2x2xf32> - %ni = stablehlo.add %iterArg_0, %step : tensor - stablehlo.return %next, %ni : tensor<2x2xf32>, tensor - } - return %w#0 : tensor<2x2xf32> + %tmp = stablehlo.constant dense<0> : tensor + return %tmp : tensor } } \ No newline at end of file From f11f1224a7d8734fa5019f5439e4556bac93f1d0 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 29 Jan 2025 15:36:54 -0600 Subject: [PATCH 14/20] Remove multiplex pass --- src/enzyme_ad/jax/Passes/CommDeplex.cpp | 42 +++++++++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 2 ++ src/enzyme_ad/jax/Passes/Passes.td | 13 ++++++-- 3 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/CommDeplex.cpp diff --git a/src/enzyme_ad/jax/Passes/CommDeplex.cpp b/src/enzyme_ad/jax/Passes/CommDeplex.cpp new file mode 100644 index 000000000..b7dd0d450 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CommDeplex.cpp @@ -0,0 +1,42 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::comm; +using namespace enzyme; using namespace mlir::enzyme; // one of the upstream includes we need is wrapped in this namespace + +namespace { +struct CommDeplex : public CommDeplexBase { + + /** + * Reassigns each use of this 
multiplex's token to one of the contributing tokens. + * + * TODO: this can potentially be a complex decision based on device load, communication + * latency, potential for removing communcations/computations outright, etc. + */ + static void chooseMultiplexMapping(CommMultiplexMessage plex) { + plex.getOutToken().replaceAllUsesWith(plex.getInTokens().front()); + } + + + void runOnOperation() override { + mlir::Operation* op = getOperation(); + op->walk([](CommMultiplexMessage plex){ + chooseMultiplexMapping(plex); + plex.erase(); + }); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace comm { +std::unique_ptr createCommDeplexPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 977d7aaac..3ac9491f9 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr createPrintPass(); namespace comm { std::unique_ptr createCommRemoveFooPass(); std::unique_ptr createCommExplodeSplitsPass(); +std::unique_ptr createCommDeplexPass(); } } // namespace mlir @@ -94,5 +95,6 @@ static void regsiterenzymeXLAPasses() { registerEnzymeHLOUnrollPass(); registerCommRemoveFooPass(); registerCommExplodeSplitsPass(); + registerCommDeplexPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 961d82d57..fc5d3fda7 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -12,7 +12,7 @@ include "mlir/Pass/PassBase.td" -def CommRemoveFoo : Pass<"remove-comm-foo"> { +def CommRemoveFoo : Pass<"comm-remove-foo"> { let summary = "Removes all comm foo operations"; let dependentDialects = [ "comm::CommunicationDialect" @@ -20,7 +20,7 @@ def CommRemoveFoo : Pass<"remove-comm-foo"> { let constructor = "mlir::comm::createCommRemoveFooPass()"; } -def CommExplodeSplits : Pass<"explode-comm-splits"> { 
+def CommExplodeSplits : Pass<"comm-explode-splits"> { let summary = "Converts a (nested) split into a non-nested split with one branch per device"; let dependentDialects = [ "comm::CommunicationDialect" @@ -28,6 +28,15 @@ def CommExplodeSplits : Pass<"explode-comm-splits"> { let constructor = "mlir::comm::createCommExplodeSplitsPass()"; } +def CommDeplex : Pass<"comm-deplex"> { + let summary = "Removes multiplex messages by replacing each use of the output token with one of the input tokens"; + let dependentDialects = [ + "comm::CommunicationDialect" + ]; + let constructor = "mlir::comm::createCommDeplexPass()"; +} + + def ArithRaisingPass : Pass<"arith-raise"> { let summary = "Raise Arith to mhlo"; let dependentDialects = [ From 3597912b81b69c5a4871661362dc9031e75cb017 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 29 Jan 2025 16:46:33 -0600 Subject: [PATCH 15/20] Weak Dead Communication Elimination pass --- .../jax/Dialects/Comm/CommDialect.td | 7 +- src/enzyme_ad/jax/Passes/CommDeplex.cpp | 2 +- .../jax/Passes/CommExplodeSplits.cpp | 2 +- .../jax/Passes/CommRemoveDeadMessages.cpp | 74 +++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 2 + src/enzyme_ad/jax/Passes/Passes.td | 11 +++ 6 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 1d72b23a5..9288d044c 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -60,7 +60,10 @@ def CommMessage : OpInterface<"CommMessage"> { }], "mlir::Type", "getInputType">, InterfaceMethod<[{ Returns what type will result from recieving this message - }], "mlir::Type", "getOutputType"> + }], "mlir::Type", "getOutputType">, + InterfaceMethod<[{ + Returns the token handle to this message + }], "mlir::TypedValue", "getToken"> ]; } @@ -181,7 +184,7 @@ def CommMultiplexMessage: 
CommMessageBase<"multiplex_msg"> { let summary = "A phi node-like message that allows the compiler to choose from any of the input messages"; let arguments = (ins TypeAttr:$data_type, Variadic:$in_tokens); let results = (outs - MessageTokenType:$out_token + MessageTokenType:$token ); let assemblyFormat = [{ attr-dict $data_type $in_tokens diff --git a/src/enzyme_ad/jax/Passes/CommDeplex.cpp b/src/enzyme_ad/jax/Passes/CommDeplex.cpp index b7dd0d450..09db29a2f 100644 --- a/src/enzyme_ad/jax/Passes/CommDeplex.cpp +++ b/src/enzyme_ad/jax/Passes/CommDeplex.cpp @@ -18,7 +18,7 @@ struct CommDeplex : public CommDeplexBase { * latency, potential for removing communcations/computations outright, etc. */ static void chooseMultiplexMapping(CommMultiplexMessage plex) { - plex.getOutToken().replaceAllUsesWith(plex.getInTokens().front()); + plex.getToken().replaceAllUsesWith(plex.getInTokens().front()); } diff --git a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp index 868a71aca..b3219bbf7 100644 --- a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp +++ b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp @@ -48,7 +48,7 @@ struct CommExplodeSplits : public CommExplodeSplitsBase { CommMultiplexMessage plex = builder.create(msg.getLoc(), out_tok_type, out_data_type, in_tokens); for(auto use : receiving_uses){ - use->assign(plex.getOutToken()); + use->assign(plex.getToken()); } return plex; } diff --git a/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp new file mode 100644 index 000000000..4ce62cc33 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp @@ -0,0 +1,74 @@ +#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::comm; +using namespace enzyme; using namespace mlir::enzyme; // 
one of the upstream includes we need is wrapped in this namespace + +namespace { +struct CommRemoveDeadMessages : public CommRemoveDeadMessagesBase { + + static bool is_live(CommMessage msg, llvm::DenseSet &live_set, llvm::DenseSet &dead_set){ + if (live_set.contains(msg)) return true; + if (dead_set.contains(msg)) return false; + + bool live = false; + for(auto user : msg->getUsers()){ + if(isa(user)) { + live = true; + break; + } + else if(CommMessage using_msg = dyn_cast(user)) { + if(is_live(using_msg, live_set, dead_set)){ + live = true; + break; + } + } else if(!isa(user)){ + // Unhandled case, assume live + live = true; + break; + } + } + + if(live) { + live_set.insert(msg); + return true; + } else { + dead_set.insert(msg); + return false; + } + } + + void runOnOperation() override { + mlir::Operation* op = getOperation(); + + llvm::DenseSet live_messages; + llvm::DenseSet dead_messages; + + op->walk([&](CommMessage msg){ + if(!is_live(msg, live_messages, dead_messages)){ + for(auto user: llvm::to_vector(msg->getUsers())){ + // Erase only the sends- the other messages need to be walked over so their own cleanup triggers + if(isa(user)){ + user->erase(); + } + } + // Todo: with multiple messages, will this cause exceptions when we go to erase those? 
+ msg->erase(); + } + }); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace comm { +std::unique_ptr createCommRemoveDeadMessagesPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 3ac9491f9..1d3743e0c 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -29,6 +29,7 @@ namespace comm { std::unique_ptr createCommRemoveFooPass(); std::unique_ptr createCommExplodeSplitsPass(); std::unique_ptr createCommDeplexPass(); +std::unique_ptr createCommRemoveDeadMessagesPass(); } } // namespace mlir @@ -96,5 +97,6 @@ static void regsiterenzymeXLAPasses() { registerCommRemoveFooPass(); registerCommExplodeSplitsPass(); registerCommDeplexPass(); + registerCommRemoveDeadMessagesPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index fc5d3fda7..6dfe8d78a 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -36,6 +36,17 @@ def CommDeplex : Pass<"comm-deplex"> { let constructor = "mlir::comm::createCommDeplexPass()"; } +def CommRemoveDeadMessages : Pass<"comm-remove-dead-messages"> { + let summary = [{ + Removes messages and senders with no receivers. + This does NOT perform a full analog to DCE- this will leave in useless communications if they + have any dependent recieve operation, regardless of whether the result of the receive is a dead value. 
+ }]; + let dependentDialects = [ + "comm::CommunicationDialect" + ]; + let constructor = "mlir::comm::createCommRemoveDeadMessagesPass()"; +} def ArithRaisingPass : Pass<"arith-raise"> { let summary = "Raise Arith to mhlo"; From 1b7493acc418deeb502e6b30a93911ac29029870 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 17 Feb 2025 15:50:02 -0600 Subject: [PATCH 16/20] Remove dummy "foo" op --- .../jax/Dialects/Comm/CommDialect.td | 9 ------ src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp | 28 ------------------- src/enzyme_ad/jax/Passes/Passes.td | 7 ----- test/comm_tests/basic_test.mlir | 4 +-- test/comm_tests/branch_duplicated_device.mlir | 5 ---- test/comm_tests/branch_out_of_split.mlir | 4 --- test/comm_tests/extra_op_in_split.mlir | 6 ---- test/comm_tests/refactor_test.mlir | 3 -- 8 files changed, 1 insertion(+), 65 deletions(-) delete mode 100644 src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td index 9288d044c..add00882b 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td @@ -71,15 +71,6 @@ def CommMessage : OpInterface<"CommMessage"> { * Dialect Ops */ -def CommFoo : CommOp<"foo"> { - let summary = "do-nothing test op"; - let arguments = (ins ); - let results = (outs ); - let assemblyFormat = [{ - attr-dict - }]; -} - // Return, for end of split blocks. 
We may just be able to use return- lets see if there's any special // semantics we want join to have def CommJoin : CommOp<"join", traits = [Terminator]> { diff --git a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp b/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp deleted file mode 100644 index 7690b1082..000000000 --- a/src/enzyme_ad/jax/Passes/CommRemoveFoo.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" -#include "src/enzyme_ad/jax/Passes/Passes.h" - -#define DEBUG_TYPE "enzyme" - -namespace mlir { -namespace enzyme { -#define GEN_PASS_DEF_COMMREMOVEFOO -#include "src/enzyme_ad/jax/Passes/Passes.h.inc" -} // namespace enzyme -} // namespace mlir - -using namespace mlir; -using namespace mlir::comm; -using namespace enzyme; -using namespace mlir::enzyme; // one of the upstream includes we need is wrapped - // in this namespace - -namespace { -struct CommRemoveFoo : public enzyme::impl::CommRemoveFooBase { - - void runOnOperation() override { - mlir::Operation *op = getOperation(); - op->walk([](CommFoo foop) { foop->erase(); }); - } -}; - -} // end anonymous namespace \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 6adfb00ed..57baddac8 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -11,13 +11,6 @@ include "mlir/Pass/PassBase.td" -def CommRemoveFoo : Pass<"comm-remove-foo"> { - let summary = "Removes all comm foo operations"; - let dependentDialects = [ - "comm::CommunicationDialect" - ]; -} - def CommExplodeSplits : Pass<"comm-explode-splits"> { let summary = "Converts a (nested) split into a non-nested split with one branch per device"; let dependentDialects = [ diff --git a/test/comm_tests/basic_test.mlir b/test/comm_tests/basic_test.mlir index 743ddc878..2224ae317 100644 --- a/test/comm_tests/basic_test.mlir +++ b/test/comm_tests/basic_test.mlir @@ -1,15 +1,13 @@ -// RUN: enzymexlamlir-opt --remove-comm-foo 
--enzyme-hlo-unroll %s | FileCheck %s +// RUN: enzymexlamlir-opt --enzyme-hlo-unroll %s | FileCheck %s module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.foo %msg = comm.simple_msg tensor<2x2xf32> comm.split {} { branch (1, 4) { ^start: - comm.foo comm.join } branch (2) { diff --git a/test/comm_tests/branch_duplicated_device.mlir b/test/comm_tests/branch_duplicated_device.mlir index b9d17de3d..e06678fd2 100644 --- a/test/comm_tests/branch_duplicated_device.mlir +++ b/test/comm_tests/branch_duplicated_device.mlir @@ -1,21 +1,16 @@ -// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s - module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.foo comm.split { %msg = comm.simple_msg tensor<2x2xf32> comm.branch [1, 4] { ^start: - comm.foo comm.split { comm.branch [1] { comm.join } comm.branch [4] { - comm.foo comm.join } } diff --git a/test/comm_tests/branch_out_of_split.mlir b/test/comm_tests/branch_out_of_split.mlir index 38ebf5ef3..76439cd1f 100644 --- a/test/comm_tests/branch_out_of_split.mlir +++ b/test/comm_tests/branch_out_of_split.mlir @@ -1,21 +1,17 @@ -// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.foo comm.split { %msg = comm.simple_msg tensor<2x2xf32> comm.branch [1, 4] { ^start: - comm.foo comm.split { comm.branch [1] { comm.join } comm.branch [4] { - comm.foo comm.join } } diff --git a/test/comm_tests/extra_op_in_split.mlir b/test/comm_tests/extra_op_in_split.mlir index c0e0dd6ab..f21632844 100644 --- a/test/comm_tests/extra_op_in_split.mlir +++ b/test/comm_tests/extra_op_in_split.mlir @@ -1,24 +1,18 @@ -// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s - module { func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> { - comm.foo comm.split { %msg = comm.simple_msg tensor<2x2xf32> comm.branch [1, 4] { ^start: - comm.foo comm.split { 
comm.branch [1] { comm.join } comm.branch [4] { - comm.foo comm.join } - comm.foo } comm.join } diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir index 35b5fc279..cc29edecc 100644 --- a/test/comm_tests/refactor_test.mlir +++ b/test/comm_tests/refactor_test.mlir @@ -1,15 +1,12 @@ -// RUN: enzymexlamlir-opt --remove-comm-foo --enzyme-hlo-unroll %s | FileCheck %s module { func.func @main(%a : tensor<2x2xf32>) -> tensor { - comm.foo comm.split { %msg = comm.simple_msg tensor<2x2xf32> %msg3 = comm.simple_msg tensor<2x2xf32> comm.branch [1, 4] { - comm.foo comm.split { %msg2 = comm.simple_msg tensor comm.branch [1] { From b1659b7906d2fad46cf82640a7b83d23345dcf4d Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 17 Feb 2025 15:57:17 -0600 Subject: [PATCH 17/20] Consolidate dialect directories --- src/enzyme_ad/jax/BUILD | 28 +++++++++--------- .../jax/Dialect/Comm/CommDialect.cpp | 29 +++++++++++++++++++ .../{Dialects => Dialect}/Comm/CommDialect.h | 10 +++---- .../{Dialects => Dialect}/Comm/CommDialect.td | 0 .../jax/Dialect/Comm/CommInterfaces.cpp | 6 ++++ .../{Dialects => Dialect}/Comm/CommOps.cpp | 4 +-- .../jax/Dialects/Comm/CommDialect.cpp | 29 ------------------- .../jax/Dialects/Comm/CommInterfaces.cpp | 6 ---- src/enzyme_ad/jax/Passes/CommDeplex.cpp | 2 +- .../jax/Passes/CommExplodeSplits.cpp | 2 +- .../jax/Passes/CommRemoveDeadMessages.cpp | 2 +- src/enzyme_ad/jax/Passes/Passes.h | 2 +- src/enzyme_ad/jax/RegistryUtils.cpp | 2 +- src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 2 +- 14 files changed, 62 insertions(+), 62 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp rename src/enzyme_ad/jax/{Dialects => Dialect}/Comm/CommDialect.h (76%) rename src/enzyme_ad/jax/{Dialects => Dialect}/Comm/CommDialect.td (100%) create mode 100644 src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp rename src/enzyme_ad/jax/{Dialects => Dialect}/Comm/CommOps.cpp (95%) delete mode 100644 
src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp delete mode 100644 src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 23295c7e5..ee1f88c8a 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -112,7 +112,7 @@ gentbl_cc_library( td_library( name = "CommDialectFiles", srcs = [ - "Dialects/Comm/CommDialect.td", + "Dialect/Comm/CommDialect.td", ], deps = [ "@llvm-project//mlir:OpBaseTdFiles" @@ -124,37 +124,37 @@ gentbl_cc_library( name = "CommDialectIncGen", tbl_outs = [( ["-gen-dialect-decls", "-dialect=comm"], - "Dialects/Comm/CommDialect.h.inc", + "Dialect/Comm/CommDialect.h.inc", ), ( ["-gen-dialect-defs", "-dialect=comm"], - "Dialects/Comm/CommDialect.cpp.inc", + "Dialect/Comm/CommDialect.cpp.inc", ),( ["-gen-op-decls", "-dialect=comm"], - "Dialects/Comm/CommOps.h.inc", + "Dialect/Comm/CommOps.h.inc", ), ( ["-gen-op-defs", "-dialect=comm"], - "Dialects/Comm/CommOps.cpp.inc", + "Dialect/Comm/CommOps.cpp.inc", ),( ["-gen-attrdef-decls", "-attrdefs-dialect=comm"], - "Dialects/Comm/CommAttrs.h.inc", + "Dialect/Comm/CommAttrs.h.inc", ),( ["-gen-attrdef-defs", "-attrdefs-dialect=comm"], - "Dialects/Comm/CommAttrs.cpp.inc", + "Dialect/Comm/CommAttrs.cpp.inc", ),( ["-gen-typedef-decls", "-typedefs-dialect=comm"], - "Dialects/Comm/CommTypes.h.inc", + "Dialect/Comm/CommTypes.h.inc", ), ( ["-gen-typedef-defs", "-typedefs-dialect=comm"], - "Dialects/Comm/CommTypes.cpp.inc", + "Dialect/Comm/CommTypes.cpp.inc", ),( ["-gen-op-interface-decls"], - "Dialects/Comm/CommInterfaces.h.inc", + "Dialect/Comm/CommInterfaces.h.inc", ), ( ["-gen-op-interface-defs"], - "Dialects/Comm/CommInterfaces.cpp.inc", + "Dialect/Comm/CommInterfaces.cpp.inc", ), ], - td_file = "Dialects/Comm/CommDialect.td", + td_file = "Dialect/Comm/CommDialect.td", deps = [ ":CommDialectFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -278,8 +278,8 @@ cc_library( cc_library( name = "CommDialect", - srcs = glob(["Dialects/*.cpp", 
"Dialects/Comm/*.cpp"]), - hdrs = glob(["Dialects/*.h", "Dialects/Comm/*.h"]), + srcs = glob(["Dialects/*.cpp", "Dialect/Comm/*.cpp"]), + hdrs = glob(["Dialects/*.h", "Dialect/Comm/*.h"]), deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp new file mode 100644 index 000000000..422937634 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp @@ -0,0 +1,29 @@ +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp.inc" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::comm; + +void CommunicationDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h similarity index 76% rename from src/enzyme_ad/jax/Dialects/Comm/CommDialect.h rename to src/enzyme_ad/jax/Dialect/Comm/CommDialect.h index 33d47ae19..2beb34996 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.h +++ b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h @@ -5,7 +5,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/Support/TypeID.h" #include "mlir/include/mlir/IR/DialectImplementation.h" -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h.inc" #include "mlir/include/mlir/IR/BuiltinOps.h" 
#include "mlir/include/mlir/IR/BuiltinTypes.h" @@ -20,15 +20,15 @@ class SplitMemberOp : public OpTrait::TraitBase { } // namespace mlir::comm #define GET_TYPEDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.h.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h.inc" -#include "src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.h.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.h.inc" #define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.h.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.h.inc" #define GET_OP_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.h.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h.inc" template mlir::LogicalResult mlir::comm::SplitMemberOp::verifyTrait(Operation *op) { diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.td similarity index 100% rename from src/enzyme_ad/jax/Dialects/Comm/CommDialect.td rename to src/enzyme_ad/jax/Dialect/Comm/CommDialect.td diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp new file mode 100644 index 000000000..71c5628e7 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp @@ -0,0 +1,6 @@ +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" + +using namespace mlir; +using namespace mlir::comm; + +#include "src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp similarity index 95% rename from src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp rename to src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp index cc1e36c8b..375f43694 100644 --- a/src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp @@ -1,5 +1,5 @@ #include "mlir/Support/LogicalResult.h" -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include 
"src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" #include "llvm/ADT/DenseSet.h" @@ -75,4 +75,4 @@ mlir::Type CommMultiplexMessage::getOutputType() { } #define GET_OP_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" \ No newline at end of file +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp deleted file mode 100644 index ac7886a3c..000000000 --- a/src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.cpp.inc" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; -using namespace mlir::comm; - -void CommunicationDialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" - >(); - - addAttributes< -#define GET_ATTRDEF_LIST -#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" - >(); - addOperations< -#define GET_OP_LIST -#include "src/enzyme_ad/jax/Dialects/Comm/CommOps.cpp.inc" - >(); -} - -#define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommAttrs.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "src/enzyme_ad/jax/Dialects/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp b/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp deleted file mode 100644 index b6fe27982..000000000 --- a/src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp +++ /dev/null @@ -1,6 +0,0 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" - -using namespace mlir; -using namespace mlir::comm; - -#include "src/enzyme_ad/jax/Dialects/Comm/CommInterfaces.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/CommDeplex.cpp b/src/enzyme_ad/jax/Passes/CommDeplex.cpp index 
deda8211a..2c76a12a3 100644 --- a/src/enzyme_ad/jax/Passes/CommDeplex.cpp +++ b/src/enzyme_ad/jax/Passes/CommDeplex.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #define DEBUG_TYPE "enzyme" diff --git a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp index f89522d41..709351d88 100644 --- a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp +++ b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "mlir/IR/Builders.h" diff --git a/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp index 720f5aff4..bece01b86 100644 --- a/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp +++ b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #define DEBUG_TYPE "enzyme" diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 8fb633fdb..8d654dfb6 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -12,7 +12,7 @@ #include "mlir/Pass/Pass.h" #include -#include "src/enzyme_ad/jax/Dialects/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" namespace mlir { diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 08060b141..1b253160a 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -52,7 +52,7 @@ #include "shardy/dialect/sdy/ir/dialect.h" -#include "Dialects/Comm/CommDialect.h" +#include "Dialect/Comm/CommDialect.h" namespace mlir { 
namespace enzyme { void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 99b1e0c19..4a979f9cb 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -60,7 +60,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "Dialects/Comm/CommDialect.h" +#include "Dialect/Comm/CommDialect.h" From ddd375fc3f068d187bde410b17f7181566f9e1fd Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Mon, 17 Feb 2025 18:29:45 -0600 Subject: [PATCH 18/20] Restructure header files --- src/enzyme_ad/jax/BUILD | 10 +---- src/enzyme_ad/jax/Dialect/Comm/Comm.h | 8 ++++ .../jax/Dialect/Comm/CommDialect.cpp | 29 +++++-------- src/enzyme_ad/jax/Dialect/Comm/CommDialect.h | 37 ++-------------- src/enzyme_ad/jax/Dialect/Comm/CommDialect.td | 15 ++----- .../jax/Dialect/Comm/CommInterfaces.cpp | 2 +- src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp | 2 +- src/enzyme_ad/jax/Dialect/Comm/CommOps.h | 42 +++++++++++++++++++ src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp | 10 +++++ src/enzyme_ad/jax/Dialect/Comm/CommTypes.h | 16 +++++++ src/enzyme_ad/jax/Passes/CommDeplex.cpp | 2 +- .../jax/Passes/CommExplodeSplits.cpp | 2 +- .../jax/Passes/CommRemoveDeadMessages.cpp | 2 +- src/enzyme_ad/jax/Passes/Passes.td | 6 +-- src/enzyme_ad/jax/RegistryUtils.cpp | 2 +- 15 files changed, 104 insertions(+), 81 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialect/Comm/Comm.h create mode 100644 src/enzyme_ad/jax/Dialect/Comm/CommOps.h create mode 100644 src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp create mode 100644 src/enzyme_ad/jax/Dialect/Comm/CommTypes.h diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index ee1f88c8a..e48104738 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -134,12 +134,6 @@ gentbl_cc_library( ), ( 
["-gen-op-defs", "-dialect=comm"], "Dialect/Comm/CommOps.cpp.inc", - ),( - ["-gen-attrdef-decls", "-attrdefs-dialect=comm"], - "Dialect/Comm/CommAttrs.h.inc", - ),( - ["-gen-attrdef-defs", "-attrdefs-dialect=comm"], - "Dialect/Comm/CommAttrs.cpp.inc", ),( ["-gen-typedef-decls", "-typedefs-dialect=comm"], "Dialect/Comm/CommTypes.h.inc", @@ -278,8 +272,8 @@ cc_library( cc_library( name = "CommDialect", - srcs = glob(["Dialects/*.cpp", "Dialect/Comm/*.cpp"]), - hdrs = glob(["Dialects/*.h", "Dialect/Comm/*.h"]), + srcs = glob(["Dialect/Comm/*.cpp"]), + hdrs = glob(["Dialect/Comm/*.h"]), deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/src/enzyme_ad/jax/Dialect/Comm/Comm.h b/src/enzyme_ad/jax/Dialect/Comm/Comm.h new file mode 100644 index 000000000..2ce86b672 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/Comm.h @@ -0,0 +1,8 @@ +#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMM_H +#define ENZYME_AD_JAX_DIALECTS_COMM_COMM_H + +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h" + +#endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp index 422937634..430f12fcd 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp +++ b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp @@ -1,29 +1,20 @@ -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp.inc" +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::comm; -void CommunicationDialect::initialize() { +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp.inc" + +void CommDialect::initialize() { addTypes< -#define GET_TYPEDEF_LIST -#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" - >(); + #define GET_TYPEDEF_LIST + #include 
"src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" + >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.cpp.inc" - >(); addOperations< -#define GET_OP_LIST -#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp.inc" - >(); + #define GET_OP_LIST + #include "src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp.inc" + >(); } - -#define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h index 2beb34996..06b9a0a63 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h +++ b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.h @@ -1,41 +1,10 @@ #ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H #define ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H -#include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" -#include "mlir/Support/TypeID.h" -#include "mlir/include/mlir/IR/DialectImplementation.h" -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h.inc" - -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Dialect.h" - -namespace mlir::comm { -template -class SplitMemberOp : public OpTrait::TraitBase { -public: - static LogicalResult verifyTrait(Operation *op); -}; -} // namespace mlir::comm - -#define GET_TYPEDEF_CLASSES -#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h.inc" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Types.h" -#include "src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.h.inc" - -#define GET_ATTRDEF_CLASSES -#include "src/enzyme_ad/jax/Dialect/Comm/CommAttrs.h.inc" - -#define GET_OP_CLASSES -#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h.inc" - -template -mlir::LogicalResult mlir::comm::SplitMemberOp::verifyTrait(Operation *op) { - if (!isa(op->getParentOp())) { - return 
op->emitOpError("must be located as immediate child of split op"); - } - return success(); -} +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h.inc" #endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.td b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.td index add00882b..53d40e46e 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/CommDialect.td +++ b/src/enzyme_ad/jax/Dialect/Comm/CommDialect.td @@ -4,7 +4,7 @@ include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/DialectBase.td" include "mlir/IR/Traits.td" -def CommunicationDialect : Dialect { +def CommDialect : Dialect { let name = "comm"; let summary = "A prototype dialect for various communication ops"; let description = [{}]; @@ -13,22 +13,15 @@ def CommunicationDialect : Dialect { } // Dialect inheritence shortcuts -class CommOp traits = []> : Op; -class CommAttr traits = []> : AttrDef{ - let mnemonic = mnemomic; -} -class CommType traits = []> : TypeDef { +class CommOp traits = []> : Op; + +class CommType traits = []> : TypeDef { let mnemonic = type_mnemonic; } /* * Dialect Types */ -def DeviceIdType : CommType<"DeviceId", "device_id"> { - let summary="Wrapper around int to specify a device from our device set"; - let parameters=(ins "unsigned":$id); - let assemblyFormat=[{`d` $id }]; -} def MessageTokenType : CommType<"MessageToken", "token"> { let summary = "Represents a consumable message token"; let mnemonic = "msg_token"; diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp index 71c5628e7..1e2a4437b 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp +++ b/src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h" using namespace mlir; using namespace mlir::comm; diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp index 
375f43694..a2e63999e 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp +++ b/src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp @@ -1,5 +1,5 @@ #include "mlir/Support/LogicalResult.h" -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h" #include "llvm/ADT/DenseSet.h" diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommOps.h b/src/enzyme_ad/jax/Dialect/Comm/CommOps.h new file mode 100644 index 000000000..73922ed30 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/CommOps.h @@ -0,0 +1,42 @@ +/** + * Contains the includes for comm ops and interfaces. + * + * These includes are difficult to separate in this case due to template + * dependencies closely interlink the definitions of the SplitMemberOp trait, + * MessageOp interface, and CommSplit ops, so they are handled in the same file. + */ +#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMMOPS_H +#define ENZYME_AD_JAX_DIALECTS_COMM_COMMOPS_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" + + +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h" + +namespace mlir::comm { +template +class SplitMemberOp : public OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op); +}; +} // namespace mlir::comm + +#include "src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h.inc" + +template +mlir::LogicalResult +mlir::comm::SplitMemberOp::verifyTrait(Operation *op) { + if (!isa(op->getParentOp())) { + return op->emitOpError("must be located as immediate child of split op"); + } + return success(); +} + +#endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp b/src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp new file mode 100644 index 000000000..afa552476 --- /dev/null +++ 
b/src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp @@ -0,0 +1,10 @@ +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h" + + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc" \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/CommTypes.h b/src/enzyme_ad/jax/Dialect/Comm/CommTypes.h new file mode 100644 index 000000000..be5a666fb --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/CommTypes.h @@ -0,0 +1,16 @@ +#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMMTYPES_H +#define ENZYME_AD_JAX_DIALECTS_COMM_COMMTYPES_H + + +#include "mlir/Support/TypeID.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Attributes.h" + +#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" + +#define GET_TYPEDEF_CLASSES +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h.inc" + +#endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/CommDeplex.cpp b/src/enzyme_ad/jax/Passes/CommDeplex.cpp index 2c76a12a3..bcc138bba 100644 --- a/src/enzyme_ad/jax/Passes/CommDeplex.cpp +++ b/src/enzyme_ad/jax/Passes/CommDeplex.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #define DEBUG_TYPE "enzyme" diff --git a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp index 709351d88..4b14046ff 100644 --- a/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp +++ b/src/enzyme_ad/jax/Passes/CommExplodeSplits.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "mlir/IR/Builders.h" diff --git 
a/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp index bece01b86..ac3393513 100644 --- a/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp +++ b/src/enzyme_ad/jax/Passes/CommRemoveDeadMessages.cpp @@ -1,4 +1,4 @@ -#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #define DEBUG_TYPE "enzyme" diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 57baddac8..935e39dcf 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -14,14 +14,14 @@ include "mlir/Pass/PassBase.td" def CommExplodeSplits : Pass<"comm-explode-splits"> { let summary = "Converts a (nested) split into a non-nested split with one branch per device"; let dependentDialects = [ - "comm::CommunicationDialect" + "comm::CommDialect" ]; } def CommDeplex : Pass<"comm-deplex"> { let summary = "Removes multiplex messages by replacing each use of the output token with one of the input tokens"; let dependentDialects = [ - "comm::CommunicationDialect" + "comm::CommDialect" ]; } @@ -30,7 +30,7 @@ def CommRemoveDeadMessages : Pass<"comm-remove-dead-messages"> { let description = [{This does NOT perform a full analog to DCE- this will leave in useless communications if they have any dependent recieve operation, regardless of whether the result of the receive is a dead value.}]; let dependentDialects = [ - "comm::CommunicationDialect" + "comm::CommDialect" ]; } diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 1b253160a..d560e70d4 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -91,7 +91,7 @@ void prepareRegistry(mlir::DialectRegistry &registry) { registry.insert(); - registry.insert(); + registry.insert(); mlir::enzyme::registerXLAAutoDiffInterfaces(registry); From
9d7a25854e67ed41785433e6f674ccd55eaa195b Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 19 Feb 2025 14:23:22 -0600 Subject: [PATCH 19/20] (Incomplete) Lower messages to jit call pass --- src/enzyme_ad/jax/Dialect/Comm/Comm.cpp | 8 + src/enzyme_ad/jax/Dialect/Comm/Comm.h | 14 +- src/enzyme_ad/jax/Passes/CommLower.cpp | 186 ++++++++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 10 ++ test/comm_tests/refactor_test.mlir | 8 + 5 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 src/enzyme_ad/jax/Dialect/Comm/Comm.cpp create mode 100644 src/enzyme_ad/jax/Passes/CommLower.cpp diff --git a/src/enzyme_ad/jax/Dialect/Comm/Comm.cpp b/src/enzyme_ad/jax/Dialect/Comm/Comm.cpp new file mode 100644 index 000000000..500c02469 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Comm/Comm.cpp @@ -0,0 +1,8 @@ +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" + +using namespace mlir::comm; + +llvm::ArrayRef mlir::comm::getOpDevices(mlir::Operation &op) { + auto parent_branch = op.getParentOfType(); + return parent_branch.getDeviceIds(); +} \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Comm/Comm.h b/src/enzyme_ad/jax/Dialect/Comm/Comm.h index 2ce86b672..d04ba3330 100644 --- a/src/enzyme_ad/jax/Dialect/Comm/Comm.h +++ b/src/enzyme_ad/jax/Dialect/Comm/Comm.h @@ -2,7 +2,19 @@ #define ENZYME_AD_JAX_DIALECTS_COMM_COMM_H #include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h" -#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h" #include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h" +#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h" + +// Utility functions + +namespace mlir::comm { + +/** + * Returns the device set of a given op. Should only be called on an op + * located within a branch. 
+ */ +llvm::ArrayRef getOpDevices(mlir::Operation &op); + +} #endif \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/CommLower.cpp b/src/enzyme_ad/jax/Passes/CommLower.cpp new file mode 100644 index 000000000..40e6137f9 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CommLower.cpp @@ -0,0 +1,186 @@ +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#define DEBUG_TYPE "enzyme" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_COMMLOWER +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace comm; +using namespace enzyme; + +namespace { + +// TODO find better tag than static global +static int tagCounter = 0; + +std::string sendFuncName = "commSendF"; +std::string recvFuncName = "commRecvF"; + +struct LowerMessageOpsPattern + : public OpRewritePattern { + + LowerMessageOpsPattern(mlir::MLIRContext *ctx) : OpRewritePattern(ctx) {} + + static uint32_t getNewTag() { return tagCounter++; } + + static LogicalResult getStaticByteSize(mlir::Value v, unsigned &num_bytes) { + if(TensorType tt = dyn_cast(v.getType())){ + if(tt.hasStaticShape()){ + unsigned numel = tt.getNumElements(); + unsigned bitwidth = tt.getElementTypeBitWidth(); + assert(bitwidth % 8 == 0 && "Currently support integer byte multiple bitwidths"); + num_bytes = numel * bitwidth / 8; + return LogicalResult::success(); + } + } + return emitError(v.getLoc(), "Unsupported message type: currently only able to get message size for staticly sized tensors"); + } + + LogicalResult matchAndRewrite(CommSimpleMessage msg, + PatternRewriter &rewriter) const override { + llvm::dbgs() << "Pattern called \n"; + 
// Collect the send, recvs of the message + auto tok = msg.getToken(); + CommSend send; + llvm::SmallVector recvs; + + for (auto user : tok.getUsers()) { + if (CommSend s = dyn_cast(user)) { + if (send) { + return s.emitOpError("Multiply send ops defined for same token!"); + } + send = s; + } else if (CommRecv r = dyn_cast(user)) { + recvs.push_back(r); + } else { + return user->emitOpError("not a handled case: was this supposed to be " + "removed prior to lowering?"); + } + } + llvm::dbgs() << "C" << send << "\n"; + + // Check that there is only one sending device + auto send_devices = getOpDevices(*send); + llvm::dbgs() << "C2\n"; + if (send_devices.size() != 1) { + return send.emitOpError( + "should be scheduled on exactly one device when lowering"); + } + auto send_device = send_devices.front(); + rewriter.setInsertionPoint(msg.getParentSplit()); + mlir::Value send_device_val = rewriter.create(send.getLoc(), send_device, 32); + + + // Create a UID for the tag and other message attributes + uint32_t tag = getNewTag(); + mlir::Value tag_val = rewriter.create(msg.getLoc(), tag, 32); + + + llvm::dbgs() << "D\n"; + + unsigned message_size; + LogicalResult try_get_size = getStaticByteSize(send.getData(), message_size); + if(try_get_size.failed()) return try_get_size; + mlir::Value message_size_val = rewriter.create(send.getLoc(), message_size, 64); + + llvm::dbgs() << "E\n"; + + for (CommRecv recv : recvs) { + llvm::dbgs() << "I\n"; + // Likewise get single receiving device + auto recv_devices = getOpDevices(*recv); + if (recv_devices.size() != 1) { + return recv.emitOpError( + "should be scheduled on exactly one device when lowering"); + } + auto recv_device = recv_devices.front(); + llvm::dbgs() << "F\n"; + + rewriter.setInsertionPoint(msg.getParentSplit()); + mlir::Value recv_device_val = rewriter.create(recv.getLoc(), recv_device, 32); + llvm::dbgs() << "G\n"; + // Replace send with a JIT call to the send function for each receiver.
// TODO: make this async, maybe make this one function instead of one per recv + rewriter.setInsertionPoint(send); + rewriter.create( + send.getLoc(), + (mlir::TypeRange){}, + mlir::FlatSymbolRefAttr::get(send.getContext(), sendFuncName), + (mlir::ValueRange){recv_device_val, tag_val, send.getData(), message_size_val}, + nullptr, // Backend config (use default) + nullptr, // Operand layouts + nullptr, // result layouts + nullptr // output aliases + ); + llvm::dbgs() << "H\n"; + + // Create a recv call for the appropriate device + llvm::dbgs() << "H2\n"; + rewriter.setInsertionPoint(recv); + auto recv_call = rewriter.create( + recv.getLoc(), + recv->getResultTypes(), + mlir::FlatSymbolRefAttr::get(send.getContext(), recvFuncName), + (mlir::ValueRange){send_device_val, tag_val, message_size_val}, + nullptr, // Backend config (use default) + nullptr, // Operand layouts + nullptr, // result layouts + nullptr // output aliases + ); + rewriter.replaceOp(recv, recv_call); + + } + llvm::dbgs() << "J\n"; + llvm::dbgs() << *msg.getParentSplit()->getParentOp(); + + rewriter.eraseOp(send); + rewriter.eraseOp(msg); + llvm::dbgs() << "K\n"; + + return LogicalResult::success(); + } +}; + +struct CommLower : public enzyme::impl::CommLowerBase { + + /** + * Reassigns each use of this multiplex's token to one of the contributing + * tokens. + * + * TODO: this can potentially be a complex decision based on device load, + * communication latency, potential for removing communications/computations + * outright, etc.
+ */ + static void chooseMultiplexMapping(CommMultiplexMessage plex) { + plex.getToken().replaceAllUsesWith(plex.getInTokens().front()); + } + + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(op->getContext()); + llvm::dbgs() << "A\n"; + patterns.add(op->getContext()); + llvm::dbgs() << "B\n"; + FrozenRewritePatternSet frozen(std::move(patterns)); + llvm::dbgs() << "Calling apply patterns\n"; + (void)mlir::applyPatternsAndFoldGreedily(op, frozen); + llvm::dbgs() << "\n\n final: \n" << *op; + + } +}; + +} // end anonymous namespace \ No newline at end of file diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 935e39dcf..d6a70bd74 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -34,6 +34,16 @@ def CommRemoveDeadMessages : Pass<"comm-remove-dead-messages"> { ]; } +def CommLower : Pass<"comm-lower"> { + let summary ="Lowers to the MPI dialect"; + let dependentDialects = [ + "comm::CommDialect", + "mpi::MPIDialect", + "enzymexla::EnzymeXLADialect", + "mlir::arith::ArithDialect" + ]; +} + def CanonicalizeLoopsPass : InterfacePass<"canonicalize-loops", "mlir::FunctionOpInterface"> { let summary = "Canonicalize loops"; diff --git a/test/comm_tests/refactor_test.mlir b/test/comm_tests/refactor_test.mlir index cc29edecc..69d35346b 100644 --- a/test/comm_tests/refactor_test.mlir +++ b/test/comm_tests/refactor_test.mlir @@ -1,6 +1,14 @@ module { + func.func @commSendF(%rank : i32, %tag: i32, %msg :!llvm.ptr<1>, %size :i64) { + return + } + + func.func @commRecvF(%rank : i32, %tag: i32, %size: i64) { + return + } + func.func @main(%a : tensor<2x2xf32>) -> tensor { comm.split { From d14ff1241151b6d93e5c542ba945bf2147209976 Mon Sep 17 00:00:00 2001 From: Egan Johnson Date: Wed, 19 Feb 2025 16:49:00 -0600 Subject: [PATCH 20/20] Delete vscode files --- .vscode/settings.json | 65 ------------------------------------------- 1 file changed, 
65 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 6387f84fe..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,65 +0,0 @@ -{ - "files.associations": { - "*.inc": "cpp", - "algorithm": "cpp", - "cstddef": "cpp", - "cstdint": "cpp", - "cstdio": "cpp", - "string": "cpp", - "cmath": "cpp", - "typeinfo": "cpp", - "cstdlib": "cpp", - "limits": "cpp", - "new": "cpp", - "type_traits": "cpp", - "vector": "cpp", - "__verbose_abort": "cpp", - "array": "cpp", - "cstring": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "execution": "cpp", - "memory": "cpp", - "initializer_list": "cpp", - "iosfwd": "cpp", - "list": "cpp", - "stdexcept": "cpp", - "unordered_map": "cpp", - "variant": "cpp", - "atomic": "cpp", - "bit": "cpp", - "*.tcc": "cpp", - "compare": "cpp", - "concepts": "cpp", - "exception": "cpp", - "functional": "cpp", - "iterator": "cpp", - "memory_resource": "cpp", - "random": "cpp", - "tuple": "cpp", - "utility": "cpp", - "cctype": "cpp", - "clocale": "cpp", - "complex": "cpp", - "condition_variable": "cpp", - "ctime": "cpp", - "deque": "cpp", - "map": "cpp", - "set": "cpp", - "istream": "cpp", - "mutex": "cpp", - "numbers": "cpp", - "numeric": "cpp", - "optional": "cpp", - "ostream": "cpp", - "ratio": "cpp", - "semaphore": "cpp", - "shared_mutex": "cpp", - "sstream": "cpp", - "stop_token": "cpp", - "streambuf": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "thread": "cpp" - } -} \ No newline at end of file