diff --git a/include/circt/Dialect/HW/HWTypes.h b/include/circt/Dialect/HW/HWTypes.h index 72d0e786e2e6..6b11e6a745f6 100644 --- a/include/circt/Dialect/HW/HWTypes.h +++ b/include/circt/Dialect/HW/HWTypes.h @@ -31,6 +31,50 @@ struct ModulePort { Direction dir; }; +static bool operator==(const ModulePort &a, const ModulePort &b) { + return a.dir == b.dir && a.name == b.name && a.type == b.type; +} +static llvm::hash_code hash_value(const ModulePort &port) { + return llvm::hash_combine(port.dir, port.name, port.type); +} + +namespace detail { +struct ModuleTypeStorage : public TypeStorage { + ModuleTypeStorage(ArrayRef inPorts); + + using KeyTy = ArrayRef; + + /// Define the comparison function for the key type. + bool operator==(const KeyTy &key) const { + return std::equal(key.begin(), key.end(), ports.begin(), ports.end()); + } + + /// Define a hash function for the key type. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine_range(key.begin(), key.end()); + } + + /// Define a construction method for creating a new instance of this storage. + static ModuleTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) ModuleTypeStorage(key); + } + + /// Construct an instance of the key from this storage class. + KeyTy getAsKey() const { return ports; } + + ArrayRef getPorts() const { return ports; } + + /// The parametric data held by the storage class. + SmallVector ports; + // Cache of common lookups + SmallVector inputToAbs; + SmallVector outputToAbs; + SmallVector absToInput; + SmallVector absToOutput; +}; +} // namespace detail + class HWSymbolCache; class ParamDeclAttr; class TypedeclOp; diff --git a/include/circt/Dialect/HW/HWTypesImpl.td b/include/circt/Dialect/HW/HWTypesImpl.td index 2f165a2a17ad..e3ee7078e7a7 100644 --- a/include/circt/Dialect/HW/HWTypesImpl.td +++ b/include/circt/Dialect/HW/HWTypesImpl.td @@ -243,6 +243,7 @@ def ModuleTypeImpl : HWType<"Module"> { let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; let mnemonic = "modty"; + let genStorageClass = 0; let extraClassDeclaration = [{ // Many of these are transitional and will be removed when modules and instances diff --git a/lib/CAPI/Dialect/FIRRTL.cpp b/lib/CAPI/Dialect/FIRRTL.cpp index 48cb961f1860..ab731b6ee3c7 100644 --- a/lib/CAPI/Dialect/FIRRTL.cpp +++ b/lib/CAPI/Dialect/FIRRTL.cpp @@ -311,6 +311,7 @@ FIRRTLValueFlow firrtlValueFoldFlow(MlirValue value, FIRRTLValueFlow flow) { case Flow::Duplex: return FIRRTL_VALUE_FLOW_DUPLEX; } + llvm_unreachable("invalid flow"); } bool firrtlImportAnnotationsFromJSONRaw( diff --git a/lib/Dialect/HW/HWOps.cpp b/lib/Dialect/HW/HWOps.cpp index aec40c91e433..d3eb281c5927 100644 --- a/lib/Dialect/HW/HWOps.cpp +++ b/lib/Dialect/HW/HWOps.cpp @@ -1075,8 +1075,6 @@ static LogicalResult verifyModuleCommon(HWModuleLike module) { assert(isa(module) && "verifier hook should only be called on modules"); - auto moduleType = module.getHWModuleType(); - SmallPtrSet paramNames; // Check parameter default values are sensible. diff --git a/lib/Dialect/HW/HWTypes.cpp b/lib/Dialect/HW/HWTypes.cpp index 2b52e9dd9245..e2f09433c42f 100644 --- a/lib/Dialect/HW/HWTypes.cpp +++ b/lib/Dialect/HW/HWTypes.cpp @@ -823,60 +823,30 @@ LogicalResult ModuleType::verify(function_ref emitError, } size_t ModuleType::getPortIdForInputId(size_t idx) { - for (auto [i, p] : llvm::enumerate(getPorts())) { - if (p.dir != ModulePort::Direction::Output) { - if (!idx) - return i; - --idx; - } - } - assert(0 && "Out of bounds input port id"); - return ~0UL; + assert(idx < getImpl()->inputToAbs.size() && "input port out of range"); + return getImpl()->inputToAbs[idx]; } size_t ModuleType::getPortIdForOutputId(size_t idx) { - for (auto [i, p] : llvm::enumerate(getPorts())) { - if (p.dir == ModulePort::Direction::Output) { - if (!idx) - return i; - --idx; - } - } - assert(0 && "Out of bounds output port id"); - return ~0UL; + assert(idx < getImpl()->outputToAbs.size() && " output port out of range"); + return getImpl()->outputToAbs[idx]; } size_t ModuleType::getInputIdForPortId(size_t idx) { - auto ports = getPorts(); - assert(ports[idx].dir != ModulePort::Direction::Output); - size_t retval = 0; - for (size_t i = 0; i < idx; ++i) - if (ports[i].dir != ModulePort::Direction::Output) - ++retval; - return retval; + auto nIdx = getImpl()->absToInput[idx]; + assert(nIdx != ~0ULL); + return nIdx; } size_t ModuleType::getOutputIdForPortId(size_t idx) { - auto ports = getPorts(); - assert(ports[idx].dir == ModulePort::Direction::Output); - size_t retval = 0; - for (size_t i = 0; i < idx; ++i) - if (ports[i].dir == ModulePort::Direction::Output) - ++retval; - return retval; + auto nIdx = getImpl()->absToOutput[idx]; + assert(nIdx != ~0ULL); + return nIdx; } -size_t ModuleType::getNumInputs() { - return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { - return p.dir != ModulePort::Direction::Output; - }); -} +size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); } -size_t ModuleType::getNumOutputs() { - return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { - return p.dir == ModulePort::Direction::Output; - }); -} +size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); } size_t ModuleType::getNumPorts() { return getPorts().size(); } @@ -984,6 +954,10 @@ FunctionType ModuleType::getFuncType() { return FunctionType::get(getContext(), inputs, outputs); } +ArrayRef ModuleType::getPorts() const { + return getImpl()->getPorts(); +} + FailureOr ModuleType::resolveParametricTypes(ArrayAttr parameters, LocationAttr loc, bool emitErrors) { @@ -1021,7 +995,7 @@ static ModulePort::Direction strToDir(StringRef str) { } /// Parse a list of field names and types within <>. E.g.: -/// +/// static ParseResult parsePorts(AsmParser &p, SmallVectorImpl &ports) { return p.parseCommaSeparatedList( @@ -1060,18 +1034,6 @@ void ModuleType::print(AsmPrinter &odsPrinter) const { printPorts(odsPrinter, getPorts()); } -namespace circt { -namespace hw { - -static bool operator==(const ModulePort &a, const ModulePort &b) { - return a.dir == b.dir && a.name == b.name && a.type == b.type; -} -static llvm::hash_code hash_value(const ModulePort &port) { - return llvm::hash_combine(port.dir, port.name, port.type); -} -} // namespace hw -} // namespace circt - ModuleType circt::hw::detail::fnToMod(Operation *op, ArrayRef inputNames, ArrayRef outputNames) { @@ -1109,6 +1071,25 @@ ModuleType circt::hw::detail::fnToMod(FunctionType fnty, return ModuleType::get(fnty.getContext(), ports); } +detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef inPorts) + : ports(inPorts) { + size_t nextInput = 0; + size_t nextOutput = 0; + for (auto [idx, p] : llvm::enumerate(ports)) { + if (p.dir == ModulePort::Direction::Output) { + outputToAbs.push_back(idx); + absToOutput.push_back(nextOutput); + absToInput.push_back(~0ULL); + ++nextOutput; + } else { + inputToAbs.push_back(idx); + absToInput.push_back(nextInput); + absToOutput.push_back(~0ULL); + ++nextInput; + } + } +} + //////////////////////////////////////////////////////////////////////////////// // BoilerPlate ////////////////////////////////////////////////////////////////////////////////