From 7b649c5c8773a20d0a063761b395f14e4f02fff2 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Fri, 10 Dec 2021 17:50:13 +0900 Subject: [PATCH] [FIRRTL/ExpandWhens] Support aggregate type registers (#2305) This commit changes ExpandWhens to handle aggregate type registers. We need to expand the connection into individual ground type elements. --- lib/Dialect/FIRRTL/Transforms/ExpandWhens.cpp | 56 +++++++++++++------ test/Dialect/FIRRTL/expand-whens-errors.mlir | 11 +--- test/Dialect/FIRRTL/expand-whens.mlir | 18 ++++++ 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/ExpandWhens.cpp b/lib/Dialect/FIRRTL/Transforms/ExpandWhens.cpp index 9c70b93ffa5e..7ff1b4a97d28 100644 --- a/lib/Dialect/FIRRTL/Transforms/ExpandWhens.cpp +++ b/lib/Dialect/FIRRTL/Transforms/ExpandWhens.cpp @@ -204,28 +204,48 @@ class LastConnectResolver : public FIRRTLVisitor { void visitDecl(WireOp op) { declareSinks(op.result(), Flow::Duplex); } + /// Take an aggregate value and construct ground subelements recursively. + /// And then apply function `fn`. + void foreachSubelement(OpBuilder &builder, Value value, + llvm::function_ref fn) { + TypeSwitch(value.getType()) + .template Case([&](BundleType bundle) { + for (auto i : llvm::seq(0u, (unsigned)bundle.getNumElements())) { + auto subfield = + builder.create(value.getLoc(), value, i); + foreachSubelement(builder, subfield, fn); + } + }) + .template Case([&](FVectorType vector) { + for (auto i : llvm::seq(0u, vector.getNumElements())) { + auto subindex = + builder.create(value.getLoc(), value, i); + foreachSubelement(builder, subindex, fn); + } + }) + .Default([&](auto) { fn(value); }); + } + void visitDecl(RegOp op) { - // Registers are initialized to themselves. - // TODO: register of aggregate types are not supported. - if (!op.getType().cast().isGround()) { - op.emitError() << "aggegate type register is not supported"; - return; - } - auto connect = OpBuilder(op->getBlock(), ++Block::iterator(op)) - .create(op.getLoc(), op, op); - driverMap[getFieldRefFromValue(op.result())] = connect; + // Registers are initialized to themselves. If the register has an + // aggergate type, connect each ground type element. + auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op)); + auto fn = [&](Value value) { + auto connect = builder.create(value.getLoc(), value, value); + driverMap[getFieldRefFromValue(value)] = connect; + }; + foreachSubelement(builder, op.result(), fn); } void visitDecl(RegResetOp op) { - // Registers are initialized to themselves. - // TODO: register of aggregate types are not supported. - if (!op.getType().cast().isGround()) { - op.emitError() << "aggegate type register is not supported"; - return; - } - auto connect = OpBuilder(op->getBlock(), ++Block::iterator(op)) - .create(op.getLoc(), op, op); - driverMap[getFieldRefFromValue(op.result())] = connect; + // Registers are initialized to themselves. If the register has an + // aggergate type, connect each ground type element. + auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op)); + auto fn = [&](Value value) { + auto connect = builder.create(value.getLoc(), value, value); + driverMap[getFieldRefFromValue(value)] = connect; + }; + foreachSubelement(builder, op.result(), fn); } void visitDecl(InstanceOp op) { diff --git a/test/Dialect/FIRRTL/expand-whens-errors.mlir b/test/Dialect/FIRRTL/expand-whens-errors.mlir index d8b27c2c9753..6e03f907273e 100644 --- a/test/Dialect/FIRRTL/expand-whens-errors.mlir +++ b/test/Dialect/FIRRTL/expand-whens-errors.mlir @@ -128,13 +128,4 @@ firrtl.circuit "CheckInitialization" { firrtl.module @CheckInitialization(in %p : !firrtl.uint<1>, out %out: !firrtl.vector, b:uint<1>>, 1>) { // expected-error @-1 {{sink "out[0].a" not fully initialized}} } -} - -// ----- - -firrtl.circuit "CheckInitialization" { -firrtl.module @CheckInitialization(in %clock: !firrtl.clock) { - // expected-error @+1 {{aggegate type register is not supported}} - %reg0 = firrtl.reg %clock : !firrtl.vector, 1> -} -} +} \ No newline at end of file diff --git a/test/Dialect/FIRRTL/expand-whens.mlir b/test/Dialect/FIRRTL/expand-whens.mlir index 106f390816f1..80566b27ee1f 100644 --- a/test/Dialect/FIRRTL/expand-whens.mlir +++ b/test/Dialect/FIRRTL/expand-whens.mlir @@ -512,4 +512,22 @@ firrtl.module @vector_of_bundle(in %p : !firrtl.uint<1>, out %ret: !firrtl.vecto // CHECK-NOT: firrtl.connect %1, %c0_ui1 : !firrtl.uint<1>, !firrtl.uint<1> // CHECK: firrtl.connect %1, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1> } + +// CHECK-LABEL: @aggregate_register +firrtl.module @aggregate_register(in %clock: !firrtl.clock) { + %0 = firrtl.reg %clock : !firrtl.bundle, b : uint<1>> + // CHECK: %1 = firrtl.subfield %0(0) + // CHECK-NEXT: firrtl.connect %1, %1 + // CHECK-NEXT: %2 = firrtl.subfield %0(1) + // CHECK-NEXT: firrtl.connect %2, %2 +} + +// CHECK-LABEL: @aggregate_regreset +firrtl.module @aggregate_regreset(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %resetval: !firrtl.vector, 2>) { + %0 = firrtl.regreset %clock, %reset, %resetval : !firrtl.uint<1>, !firrtl.vector, 2>, !firrtl.vector, 2> + // CHECK: %1 = firrtl.subindex %0[0] + // CHECK-NEXT: firrtl.connect %1, %1 + // CHECK-NEXT: %2 = firrtl.subindex %0[1] + // CHECK-NEXT: firrtl.connect %2, %2 +} }