Skip to content

Commit 94bfd42

Browse files
authored
feat(mlir): if/cond (#1379)
1 parent 958197d commit 94bfd42

File tree

11 files changed

+355
-147
lines changed

11 files changed

+355
-147
lines changed

exla/c_src/exla/exla.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,7 @@ static ErlNifFunc exla_funcs[] = {
721721
{"mlir_dynamic_update_slice", 4, mlir_dynamic_update_slice},
722722
{"mlir_reduce", 5, mlir_reduce},
723723
{"mlir_map", 4, mlir_map},
724+
{"mlir_if", 6, mlir_if},
724725
// XlaBuilder
725726
{"new_builder", 1, new_builder},
726727
{"create_sub_builder", 2, create_sub_builder},
@@ -821,7 +822,6 @@ static ErlNifFunc exla_funcs[] = {
821822
{"get_tuple_element", 2, get_tuple_element},
822823
// Control Flow
823824
{"conditional", 5, conditional_if},
824-
{"conditional", 3, conditional_multi},
825825
{"select", 3, select},
826826
{"while", 3, while_loop},
827827
{"call", 3, call},

exla/c_src/exla/exla_ops.cc

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -118,30 +118,6 @@ ERL_NIF_TERM conditional_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
118118
return exla::nif::ok(env, exla::nif::make<xla::XlaOp>(env, op));
119119
}
120120

121-
ERL_NIF_TERM conditional_multi(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
122-
if (argc != 3) {
123-
return exla::nif::error(env, "Bad argument count.");
124-
}
125-
126-
xla::XlaOp* index;
127-
std::vector<xla::XlaComputation*> branches;
128-
std::vector<xla::XlaOp> operands;
129-
130-
if (!exla::nif::get<xla::XlaOp>(env, argv[0], index)) {
131-
return exla::nif::error(env, "Unable to get index.");
132-
}
133-
if (!exla::nif::get_list<xla::XlaComputation*>(env, argv[1], branches)) {
134-
return exla::nif::error(env, "Unable to get branches.");
135-
}
136-
if (!exla::nif::get_list<xla::XlaOp>(env, argv[2], operands)) {
137-
return exla::nif::error(env, "Unable to get operands.");
138-
}
139-
140-
xla::XlaOp op = xla::Conditional(*index, branches, operands);
141-
142-
return exla::nif::ok(env, exla::nif::make<xla::XlaOp>(env, op));
143-
}
144-
145121
ERL_NIF_TERM select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
146122
if (argc != 3) {
147123
return exla::nif::error(env, "Bad argument count.");

exla/c_src/exla/exla_ops.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ ERL_NIF_TERM get_tuple_element(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv
1010

1111
// Control Flow
1212
ERL_NIF_TERM conditional_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
13-
ERL_NIF_TERM conditional_multi(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
1413
ERL_NIF_TERM select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
1514
ERL_NIF_TERM while_loop(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
1615
ERL_NIF_TERM call(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);

exla/c_src/exla/mlir/builder.cc

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbersToAttr(mlir::
102102
rhsContractingVec);
103103
}
104104

105+
void MLIRFunction::dump_mlir_module() {
106+
module_->module().dump();
107+
}
108+
105109
int MLIRFunction::get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type) {
106110
auto builder = module_->builder();
107111
std::string type_str;
@@ -803,6 +807,52 @@ mlir::Value MLIRFunction::MapOp(
803807
return map_op;
804808
}
805809

810+
// adapted from xla/translate/hlo_to_mhlo/hlo_function_importer.cc
811+
// we need to adapt because we want to receive std::vector and
812+
// because we use stablehlo instead of mhlo here.
813+
void ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation *op, std::vector<mlir::Value> implicit_operands) {
814+
if (!op) {
815+
std::cerr << "op is null" << std::endl;
816+
return;
817+
}
818+
int implicit_operand_index = 0;
819+
for (auto &region : op->getRegions()) {
820+
for (auto arg : region.getArguments()) {
821+
arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
822+
}
823+
region.front().eraseArguments(0, region.getNumArguments());
824+
}
825+
}
826+
827+
mlir::Value MLIRFunction::IfOp(mlir::Value pred, xla::Shape output_shape, std::vector<mlir::Value> implicit_arguments, MLIRFunction *on_true, MLIRFunction *on_false) {
828+
auto builder = module_->builder();
829+
builder->setInsertionPointToEnd(&func_->getBody().back());
830+
831+
auto span = output_shape.dimensions();
832+
std::vector<tsl::int64> dims(span.begin(), span.end());
833+
mlir::Type output_type = GetMLIRType(builder, dims, output_shape.element_type());
834+
835+
pred = builder->create<mlir::stablehlo::ConvertOp>(builder->getUnknownLoc(), pred, builder->getIntegerType(1));
836+
837+
implicit_arguments.insert(implicit_arguments.begin(), pred);
838+
mlir::ValueRange operands(implicit_arguments);
839+
840+
mlir::stablehlo::IfOp if_op = builder->create<mlir::stablehlo::IfOp>(builder->getUnknownLoc(), mlir::TypeRange({output_type}), pred);
841+
842+
mlir::Region &trueBody = if_op.getTrueBranch();
843+
auto &onTrueBlocks = on_true->function()->getBody().getBlocks();
844+
trueBody.getBlocks().splice(trueBody.end(), onTrueBlocks);
845+
846+
mlir::Region &falseBody = if_op.getFalseBranch();
847+
auto &onFalseBlocks = on_false->function()->getBody().getBlocks();
848+
falseBody.getBlocks().splice(falseBody.end(), onFalseBlocks);
849+
850+
implicit_arguments.erase(implicit_arguments.begin());
851+
ReplaceBlockArgumentsWithImplicitOperands(if_op.getOperation(), implicit_arguments);
852+
853+
return if_op.getResult(0);
854+
}
855+
806856
mlir::Value MLIRFunction::SelectAndScatterOp(
807857
mlir::Value target,
808858
mlir::Value source,
@@ -967,8 +1017,6 @@ void MLIRFunction::Build(mlir::Value root, bool use_stablehlo_return) {
9671017
} else {
9681018
module_->builder()->create<mlir::func::ReturnOp>(module_->builder()->getUnknownLoc(), root);
9691019
}
970-
971-
module_->LowerPatterns();
9721020
}
9731021

9741022
MLIRModule::MLIRModule() {
@@ -1151,4 +1199,19 @@ void MLIRModule::LowerPatterns() {
11511199
mlir::applyPartialConversion(module(), target, std::move(patterns));
11521200
}
11531201

1202+
void MLIRModule::RemoveEmptyFunctions() {
1203+
std::vector<mlir::func::FuncOp> unused_functions;
1204+
for (auto &op : module_->getOps()) {
1205+
if (auto func = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
1206+
if (func.getBody().empty()) {
1207+
unused_functions.push_back(func);
1208+
}
1209+
}
1210+
}
1211+
1212+
for (auto func : unused_functions) {
1213+
func.erase();
1214+
}
1215+
}
1216+
11541217
} // namespace exla

exla/c_src/exla/mlir/builder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class MLIRFunction {
106106
mlir::Value DynamicUpdateSliceOp(mlir::Value operand, mlir::Value update, std::vector<mlir::Value> start_indices);
107107
std::vector<mlir::Value> ReduceOp(MLIRFunction *function, std::vector<mlir::Value> init_values, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
108108
mlir::Value MapOp(MLIRFunction *function, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
109+
mlir::Value IfOp(mlir::Value pred, xla::Shape output_shape, std::vector<mlir::Value> implicit_args, MLIRFunction *on_true, MLIRFunction *on_false);
109110
ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::vector<int64_t> dims = {});
110111
int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);
111112

@@ -118,6 +119,8 @@ class MLIRFunction {
118119
private:
119120
std::shared_ptr<MLIRModule> module_;
120121
std::unique_ptr<mlir::func::FuncOp> func_;
122+
123+
void dump_mlir_module();
121124
};
122125

123126
class MLIRModule {
@@ -133,13 +136,12 @@ class MLIRModule {
133136
mlir::OpBuilder *builder() { return builder_.get(); }
134137
mlir::MLIRContext *context() { return context_.get(); }
135138
void LowerPatterns();
139+
void RemoveEmptyFunctions();
136140

137141
private:
138142
std::unique_ptr<mlir::MLIRContext> context_;
139143
mlir::OwningOpRef<mlir::ModuleOp> module_;
140144
std::unique_ptr<mlir::OpBuilder> builder_;
141-
142-
std::vector<mlir::Type> input_types_;
143145
};
144146

145147
mlir::Type TypeIntToMLIRType(mlir::OpBuilder *builder, xla::PrimitiveType type_int);

exla/c_src/exla/mlir/ops.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
4545
return exla::nif::error(env, "Unable to get device ID.");
4646
}
4747

48+
(*module)->LowerPatterns();
49+
(*module)->RemoveEmptyFunctions();
50+
4851
build_options.set_num_replicas(num_replicas);
4952
build_options.set_num_partitions(num_partitions);
5053
build_options.set_use_spmd_partitioning(use_spmd);
@@ -831,6 +834,42 @@ ERL_NIF_TERM mlir_map(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
831834
return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, result));
832835
}
833836

837+
ERL_NIF_TERM mlir_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
838+
if (argc != 6) {
839+
return exla::nif::error(env, "Bad argument count.");
840+
}
841+
842+
exla::MLIRFunction** function;
843+
mlir::Value* pred;
844+
std::vector<mlir::Value> implicit_args;
845+
exla::MLIRFunction** on_true;
846+
exla::MLIRFunction** on_false;
847+
xla::Shape* output_shape;
848+
849+
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
850+
return exla::nif::error(env, "Unable to get function.");
851+
}
852+
if (!exla::nif::get<mlir::Value>(env, argv[1], pred)) {
853+
return exla::nif::error(env, "Unable to get pred.");
854+
}
855+
if (!exla::nif::get<xla::Shape>(env, argv[2], output_shape)) {
856+
return exla::nif::error(env, "Unable to get shape.");
857+
}
858+
if (!exla::nif::get_list<mlir::Value>(env, argv[3], implicit_args)) {
859+
return exla::nif::error(env, "Unable to get implicit_args.");
860+
}
861+
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[4], on_true)) {
862+
return exla::nif::error(env, "Unable to get on_true.");
863+
}
864+
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[5], on_false)) {
865+
return exla::nif::error(env, "Unable to get on_false.");
866+
}
867+
868+
mlir::Value result = (*function)->IfOp(*pred, *output_shape, implicit_args, *on_true, *on_false);
869+
870+
return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, result));
871+
}
872+
834873
ERL_NIF_TERM mlir_bitcast_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
835874
if (argc != 4) {
836875
return exla::nif::error(env, "Bad argument count.");

0 commit comments

Comments
 (0)