Skip to content

WIP Prototype communcation dialect #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ cc_binary(
"//src/enzyme_ad/jax:RaisingTransformOps",
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:CommDialect",
"@stablehlo//:chlo_ops",
"@stablehlo//stablehlo/tests:check_ops",
"@shardy//shardy/dialect/sdy/ir:dialect",
Expand Down
63 changes: 63 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,56 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
)


td_library(
name = "CommDialectFiles",
srcs = [
"Dialect/Comm/CommDialect.td",
],
deps = [
"@llvm-project//mlir:OpBaseTdFiles"
]
)


gentbl_cc_library(
name = "CommDialectIncGen",
tbl_outs = [(
["-gen-dialect-decls", "-dialect=comm"],
"Dialect/Comm/CommDialect.h.inc",
), (
["-gen-dialect-defs", "-dialect=comm"],
"Dialect/Comm/CommDialect.cpp.inc",
),(
["-gen-op-decls", "-dialect=comm"],
"Dialect/Comm/CommOps.h.inc",
), (
["-gen-op-defs", "-dialect=comm"],
"Dialect/Comm/CommOps.cpp.inc",
),(
["-gen-typedef-decls", "-typedefs-dialect=comm"],
"Dialect/Comm/CommTypes.h.inc",
), (
["-gen-typedef-defs", "-typedefs-dialect=comm"],
"Dialect/Comm/CommTypes.cpp.inc",
),(
["-gen-op-interface-decls"],
"Dialect/Comm/CommInterfaces.h.inc",
), (
["-gen-op-interface-defs"],
"Dialect/Comm/CommInterfaces.cpp.inc",
),
],
td_file = "Dialect/Comm/CommDialect.td",
deps = [
":CommDialectFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:AttrTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles"
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
)

gentbl_cc_library(
name = "TransformOpsImplIncGen",
tbl_outs = [(
Expand Down Expand Up @@ -215,10 +265,22 @@ cc_library(
"@llvm-project//mlir:TransformDialectInterfaces",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":CommDialectIncGen",
":XLADerivatives",
],
)

cc_library(
name = "CommDialect",
srcs = glob(["Dialect/Comm/*.cpp"]),
hdrs = glob(["Dialect/Comm/*.h"]),
deps = [
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
":CommDialectIncGen",
],
)

td_library(
name = "ImplementationsCommonTdFiles",
srcs = [
Expand Down Expand Up @@ -392,6 +454,7 @@ cc_library(
"-Werror=unused-result",
],
deps = [
":CommDialect",
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
":EnzymeHLOPatternsIncGen",
Expand Down
8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/Comm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h"

using namespace mlir::comm;

llvm::ArrayRef<int32_t> mlir::comm::getOpDevices(mlir::Operation &op) {
auto parent_branch = op.getParentOfType<CommBranch>();
return parent_branch.getDeviceIds();
}
20 changes: 20 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/Comm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMM_H
#define ENZYME_AD_JAX_DIALECTS_COMM_COMM_H

#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h"
#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h"
#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.h"

// Utility functions

namespace mlir::comm {

/**
* Returns the device set of a given op. Should only be called on an op
* located within a branch.
*/
llvm::ArrayRef<int32_t> getOpDevices(mlir::Operation &op);

}

#endif
20 changes: 20 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "src/enzyme_ad/jax/Dialect/Comm/Comm.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::comm;

#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.cpp.inc"

void CommDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "src/enzyme_ad/jax/Dialect/Comm/CommTypes.cpp.inc"
>();

addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.cpp.inc"
>();
}
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/CommDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H
#define ENZYME_AD_JAX_DIALECTS_COMM_COMMDIALECT_H

#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"

#include "src/enzyme_ad/jax/Dialect/Comm/CommDialect.h.inc"

#endif
177 changes: 177 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/CommDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/DialectBase.td"
include "mlir/IR/Traits.td"

def CommDialect : Dialect {
let name = "comm";
let summary = "A prototype dialect for various communication ops";
let description = [{}];
let cppNamespace = "::mlir::comm";
let useDefaultTypePrinterParser = 1;
}

// Dialect inheritence shortcuts
class CommOp<string name, list<Trait> traits = []> : Op<CommDialect, name, traits>;

class CommType<string name, string type_mnemonic, list<Trait> traits = []> : TypeDef<CommDialect, name, traits> {
let mnemonic = type_mnemonic;
}

/*
* Dialect Types
*/
def MessageTokenType : CommType<"MessageToken", "token"> {
let summary = "Represents a consumable message token";
let mnemonic = "msg_token";
}

/*
* Dialect traits and interfaces
*/
def CommSplitMemberOpTrait : NativeOpTrait<"SplitMemberOp",
/*traits=*/[],
/*extraOpDeclaration = */[{
mlir::comm::CommSplit getParentSplit();
}],
/*extraOpDefinition = */[{
mlir::comm::CommSplit $cppClass::getParentSplit(){
// Verifier checks that this is indeed of the correct type
return dyn_cast<mlir::comm::CommSplit>(getOperation()->getParentOp());
}
}]
>{
let cppNamespace = "::mlir::comm";
}

def CommMessage : OpInterface<"CommMessage"> {
let cppNamespace = "::mlir::comm";
let methods = [
InterfaceMethod<[{
Returns what type this message takes as inputs
}], "mlir::Type", "getInputType">,
InterfaceMethod<[{
Returns what type will result from recieving this message
}], "mlir::Type", "getOutputType">,
InterfaceMethod<[{
Returns the token handle to this message
}], "mlir::TypedValue<mlir::comm::MessageTokenType>", "getToken">
];
}

/*
* Dialect Ops
*/

// Return, for end of split blocks. We may just be able to use return- lets see if there's any special
// semantics we want join to have
def CommJoin : CommOp<"join", traits = [Terminator]> {
let summary = "Denotes the end of a split block, similar to ret for a function";
let arguments = (ins );
let results = (outs );
let assemblyFormat = [{
attr-dict
}];
}

def CommSplit : CommOp<"split", traits = [SingleBlock, NoTerminator]> {
let summary = "The highest level split node in the communication dialect.";
let description = [{
Takes in a definition of communication items and a list of split branches for devices to take.
Encoded as a single-block no-terminator region that consists only of branches and communcation token declarations.
Example syntax:
comm.split {
%1 = comm.simple_msg msg_type
comm.branch [1, 4] {
// ... comm branch region
}
comm.branch [2] {
// ... comm branch region
}
}
}];

let arguments = (ins ); // no inputs yet, encoded in the region
let regions = (region SizedRegion<1>:$declarations);
let results = (outs );

let assemblyFormat = [{
$declarations attr-dict
}];

let hasVerifier = 1;

// Add some convenience getters to hide the mess around having a declarations region
let extraClassDeclaration = [{
auto getMessages() {
return getDeclarations().getOps<::mlir::comm::CommMessage>();
}
auto getBranches() {
return getDeclarations().getOps<::mlir::comm::CommBranch>();
}
}];
}

def CommBranch : CommOp<"branch", traits = [CommSplitMemberOpTrait]> {
let summary = "Represents one branch that can be taken by a split node";
let arguments = (ins DenseI32ArrayAttr:$device_ids);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict $device_ids $region
}];
}

def CommSend: CommOp<"send"> {
let summary = "An op to fulfill (part of) a messages input.";
let arguments = (ins MessageTokenType:$token, AnyType:$data);
let results = (outs );
let assemblyFormat = [{
attr-dict $token $data `:` type($data)
}];
let extraClassDeclaration = [{
CommSimpleMessage getMessage();
}];
let hasVerifier = 1;
}

def CommRecv: CommOp<"recv"> {
let summary = "An op that blocks and returns the messages output";
let arguments = (ins MessageTokenType:$token);
let results = (outs AnyType:$data);
let assemblyFormat = [{
attr-dict $token `:` type($data)
}];
}

/*
* Different types of message ops
*/
// Base class for messages
class CommMessageBase<string name, list<Trait> extra_traits = []>: CommOp<name, traits = extra_traits # [DeclareOpInterfaceMethods<CommMessage>, CommSplitMemberOpTrait]>;

// Message types. In the future we will likely want to have a common base class
def CommSimpleMessage: CommMessageBase<"simple_msg"> {
let summary = "A simple single-usage, one-way message token";
let arguments = (ins
TypeAttr:$data_type
);
let results = (outs
MessageTokenType:$token
);
let assemblyFormat = [{
attr-dict $data_type
}];
}

def CommMultiplexMessage: CommMessageBase<"multiplex_msg"> {
let summary = "A phi node-like message that allows the compiler to choose from any of the input messages";
let arguments = (ins TypeAttr:$data_type, Variadic<MessageTokenType>:$in_tokens);
let results = (outs
MessageTokenType:$token
);
let assemblyFormat = [{
attr-dict $data_type $in_tokens
}];
let hasVerifier = 1;
}
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include "src/enzyme_ad/jax/Dialect/Comm/CommOps.h"

using namespace mlir;
using namespace mlir::comm;

#include "src/enzyme_ad/jax/Dialect/Comm/CommInterfaces.cpp.inc"
Loading