diff --git a/include/circt/Dialect/FIRRTL/FIRRTLUtils.h b/include/circt/Dialect/FIRRTL/FIRRTLUtils.h index cbd3251861ce..00c731247132 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLUtils.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLUtils.h @@ -140,6 +140,20 @@ inline FIRRTLBaseType getBaseType(FIRRTLType type) { .Case([](auto ref) { return ref.getType(); }); } +/// Return base type or passthrough if FIRRTLType, else null. +inline FIRRTLBaseType getBaseTypeOrNull(Type type) { + auto ftype = dyn_cast_or_null(type); + if (!ftype) + return {}; + return getBaseType(ftype); +} + +/// Get base type if isa<> the requested type, else null. +template +inline T getBaseOfType(Type type) { + return dyn_cast_or_null(getBaseTypeOrNull(type)); +} + /// Return a FIRRTLType with its base type component mutated by the given /// function. (i.e., ref -> ref and T -> f(T)). inline FIRRTLType mapBaseType(FIRRTLType type, diff --git a/lib/Dialect/FIRRTL/Import/FIRParser.cpp b/lib/Dialect/FIRRTL/Import/FIRParser.cpp index cb06c6bb3556..ce4dc3a2ff6d 100644 --- a/lib/Dialect/FIRRTL/Import/FIRParser.cpp +++ b/lib/Dialect/FIRRTL/Import/FIRParser.cpp @@ -1536,9 +1536,7 @@ ParseResult FIRStmtParser::parsePostFixFieldId(Value &result) { StringRef fieldName; if (parseFieldId(fieldName, "expected field name")) return failure(); - auto bundle = - dyn_cast(getBaseType(cast(result.getType()))); - + auto bundle = getBaseOfType(result.getType()); if (!bundle) return emitError(loc, "subfield requires bundle operand "); auto indexV = bundle.getElementIndex(fieldName); diff --git a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp index cd190cd51233..95b7b1bdae0f 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp @@ -1762,16 +1762,14 @@ LogicalResult InferResetsPass::verifyNoAbstractReset() { for (FModuleLike module : getOperation().getBodyBlock()->getOps()) { for (PortInfo port : module.getPorts()) { - if (auto portType = port.type.dyn_cast()) { - if (getBaseType(portType).isa()) { - auto diag = emitError(port.loc) - << "a port \"" << port.getName() - << "\" with abstract reset type was unable to be " - "inferred by InferResets (is this a top-level port?)"; - diag.attachNote(module->getLoc()) - << "the module with this uninferred reset port was defined here"; - hasAbstractResetPorts = true; - } + if (getBaseOfType(port.type)) { + auto diag = emitError(port.loc) + << "a port \"" << port.getName() + << "\" with abstract reset type was unable to be " + "inferred by InferResets (is this a top-level port?)"; + diag.attachNote(module->getLoc()) + << "the module with this uninferred reset port was defined here"; + hasAbstractResetPorts = true; } } }