diff --git a/include/circt/Dialect/RTGTest/IR/RTGTestDialect.h b/include/circt/Dialect/RTGTest/IR/RTGTestDialect.h index ae8fe4678e05..2e4c831c5284 100644 --- a/include/circt/Dialect/RTGTest/IR/RTGTestDialect.h +++ b/include/circt/Dialect/RTGTest/IR/RTGTestDialect.h @@ -14,6 +14,7 @@ #ifndef CIRCT_DIALECT_RTGTEST_IR_RTGTESTDIALECT_H #define CIRCT_DIALECT_RTGTEST_IR_RTGTESTDIALECT_H +#include "circt/Dialect/RTG/IR/RTGDialect.h" #include "circt/Support/LLVM.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" diff --git a/include/circt/Dialect/RTGTest/IR/RTGTestDialect.td b/include/circt/Dialect/RTGTest/IR/RTGTestDialect.td index fe9a00d2fccb..e651b11bbbf2 100644 --- a/include/circt/Dialect/RTGTest/IR/RTGTestDialect.td +++ b/include/circt/Dialect/RTGTest/IR/RTGTestDialect.td @@ -27,6 +27,11 @@ def RTGTestDialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + + let hasConstantMaterializer = 1; + + let dependentDialects = ["::circt::rtg::RTGDialect"]; + let cppNamespace = "::circt::rtgtest"; let extraClassDeclaration = [{ diff --git a/lib/Dialect/RTGTest/IR/RTGTestDialect.cpp b/lib/Dialect/RTGTest/IR/RTGTestDialect.cpp index f93f1511a990..cd3fb0100fa9 100644 --- a/lib/Dialect/RTGTest/IR/RTGTestDialect.cpp +++ b/lib/Dialect/RTGTest/IR/RTGTestDialect.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/RTGTest/IR/RTGTestDialect.h" +#include "circt/Dialect/RTG/IR/RTGOps.h" #include "circt/Dialect/RTGTest/IR/RTGTestOps.h" #include "circt/Dialect/RTGTest/IR/RTGTestTypes.h" #include "mlir/IR/Builders.h" @@ -34,6 +35,39 @@ void RTGTestDialect::initialize() { >(); } +/// Registered hook to materialize a single constant operation from a given +/// attribute value with the desired resultant type. This method should use +/// the provided builder to create the operation without changing the +/// insertion position. The generated operation is expected to be constant +/// like, i.e. single result, zero operands, non side-effecting, etc. On +/// success, this hook should return the value generated to represent the +/// constant value. Otherwise, it should return null on failure. +Operation *RTGTestDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + if (auto attr = dyn_cast(value)) + if (isa(type)) + return builder.create(loc, type, attr); + + if (auto attr = dyn_cast(value)) + if (isa(type)) + return builder.create(loc, attr); + + if (auto attr = dyn_cast(value)) + if (isa(type)) + return builder.create(loc, attr); + + if (auto attr = dyn_cast(value)) + if (isa(type)) + return builder.create(loc, attr); + + if (auto attr = dyn_cast(value)) + if (isa(type)) + return builder.create(loc, attr); + + return nullptr; +} + #include "circt/Dialect/RTGTest/IR/RTGTestEnums.cpp.inc" #include "circt/Dialect/RTGTest/IR/RTGTestDialect.cpp.inc" diff --git a/unittests/Dialect/CMakeLists.txt b/unittests/Dialect/CMakeLists.txt index e5c67d266fdb..932d4bf03010 100644 --- a/unittests/Dialect/CMakeLists.txt +++ b/unittests/Dialect/CMakeLists.txt @@ -3,4 +3,5 @@ add_subdirectory(FIRRTL) add_subdirectory(ESI) add_subdirectory(HW) add_subdirectory(OM) +add_subdirectory(RTGTest) add_subdirectory(SMT) diff --git a/unittests/Dialect/RTGTest/CMakeLists.txt b/unittests/Dialect/RTGTest/CMakeLists.txt new file mode 100644 index 000000000000..0370a8f9e4d5 --- /dev/null +++ b/unittests/Dialect/RTGTest/CMakeLists.txt @@ -0,0 +1,9 @@ +add_circt_unittest(CIRCTRTGTestTests + MaterializerTest.cpp +) + +target_link_libraries(CIRCTRTGTestTests + PRIVATE + CIRCTRTGTestDialect + MLIRIR +) diff --git a/unittests/Dialect/RTGTest/MaterializerTest.cpp b/unittests/Dialect/RTGTest/MaterializerTest.cpp new file mode 100644 index 000000000000..673bbf1d5e78 --- /dev/null +++ b/unittests/Dialect/RTGTest/MaterializerTest.cpp @@ -0,0 +1,71 @@ +//===- MaterializerTest.cpp - RTGTest Dialect Materializer unit tests -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/RTGTest/IR/RTGTestAttributes.h" +#include "circt/Dialect/RTGTest/IR/RTGTestDialect.h" +#include "circt/Dialect/RTGTest/IR/RTGTestOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace circt; +using namespace rtgtest; + +namespace { + +TEST(MaterializerTest, CPUAttr) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + auto moduleOp = ModuleOp::create(loc); + OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto attr = CPUAttr::get(&context, 0); + auto *op = context.getLoadedDialect()->materializeConstant( + builder, attr, attr.getType(), loc); + ASSERT_TRUE(op && isa(op)); +} + +TEST(MaterializerTest, ImmediateAttr) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + auto moduleOp = ModuleOp::create(loc); + OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto attr12 = Imm12Attr::get(&context, 0); + auto attr21 = Imm32Attr::get(&context, 0); + auto attr32 = Imm21Attr::get(&context, 0); + + auto *op12 = context.getLoadedDialect()->materializeConstant( + builder, attr12, attr12.getType(), loc); + auto *op21 = context.getLoadedDialect()->materializeConstant( + builder, attr21, attr21.getType(), loc); + auto *op32 = context.getLoadedDialect()->materializeConstant( + builder, attr32, attr32.getType(), loc); + + ASSERT_TRUE(op12 && isa(op12)); + ASSERT_TRUE(op21 && isa(op21)); + ASSERT_TRUE(op32 && isa(op32)); +} + +TEST(MaterializerTest, RegisterAttr) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + auto moduleOp = ModuleOp::create(loc); + OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto attr = RegZeroAttr::get(&context); + auto *op = context.getLoadedDialect()->materializeConstant( + builder, attr, attr.getType(), loc); + ASSERT_TRUE(op && isa(op)); +} + +} // namespace