From 95c011bb0ac403a275cc0f8b2b2dc6f1ca811394 Mon Sep 17 00:00:00 2001 From: John Demme Date: Mon, 17 Jul 2023 14:06:36 -0700 Subject: [PATCH] [ESI] Improve the way to/from server requests are handled (#5586) Fixes #5490. Don't break apart in/out requests. Create an op interface to avoid having to reason about three different ops. --- include/circt/Dialect/ESI/ESIInterfaces.td | 65 ++++ include/circt/Dialect/ESI/ESIServices.td | 34 +- lib/Dialect/ESI/ESIOps.cpp | 55 ---- lib/Dialect/ESI/ESIServices.cpp | 343 +++++++++------------ test/Dialect/ESI/errors.mlir | 21 ++ test/Dialect/ESI/services.mlir | 26 +- 6 files changed, 279 insertions(+), 265 deletions(-) diff --git a/include/circt/Dialect/ESI/ESIInterfaces.td b/include/circt/Dialect/ESI/ESIInterfaces.td index 8205688cff31..6864113b313e 100644 --- a/include/circt/Dialect/ESI/ESIInterfaces.td +++ b/include/circt/Dialect/ESI/ESIInterfaces.td @@ -64,3 +64,68 @@ def ServiceDeclOpInterface : OpInterface<"ServiceDeclOpInterface"> { }]>, ]; } + +def ServiceReqOpInterface : OpInterface<"ServiceReqOpInterface"> { + let cppNamespace = "circt::esi"; + let description = [{ + Any op which is requesting connection to a service's port or has information + pertaining to one of said requests. + }]; + + let methods = [ + InterfaceMethod< + "Returns the service port symbol.", + "hw::InnerRefAttr", "getServicePort", (ins) + >, + InterfaceMethod< + "Returns the client name path.", + "ArrayAttr", "getClientNamePath", (ins) + >, + InterfaceMethod< + "Set the client name path.", + "void", "setClientNamePath", (ins "ArrayAttr":$newClientNamePath), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + $_op.setClientNamePathAttr(newClientNamePath); + }]>, + InterfaceMethod< + "Returns the type headed to the client. Can be null.", + "Type", "getToClientType", (ins), + /*methodBody=*/[{ + Value toClient = $_op.getToClient(); + if (!toClient) + return {}; + return toClient.getType(); + }]>, + InterfaceMethod< + "Returns the value headed to the client. Can be null.", + "Value", "getToClient", (ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return {}; + }]>, + InterfaceMethod< + "Returns the type headed to the server. Can be null.", + "Type", "getToServerType", (ins), + /*methodBody=*/[{ + Value toServer = $_op.getToServer(); + if (!toServer) + return {}; + return toServer.getType(); + }]>, + InterfaceMethod< + "Returns the value headed to the server. Can be null.", + "Value", "getToServer", (ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return {}; + }]>, + InterfaceMethod< + "Set the value headed to the server. Can assert().", + "void", "setToServer", (ins "Value":$newValue), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(false && "This op doesn't support the to server direction"); + }]> + ]; +} diff --git a/include/circt/Dialect/ESI/ESIServices.td b/include/circt/Dialect/ESI/ESIServices.td index 06b7feb18132..b68142eeb269 100644 --- a/include/circt/Dialect/ESI/ESIServices.td +++ b/include/circt/Dialect/ESI/ESIServices.td @@ -120,15 +120,6 @@ def ServiceImplementReqOp : ESI_Op<"service.impl_req", [NoTerminator]> { let results = (outs Variadic:$outputs); let regions = (region SizedRegion<1>:$portReqs); - let extraClassDeclaration = [{ - /// Find pairs of toServer and toClient requests with the same client path. - /// In cases where there doesn't exist a pair, one of the two entries will - /// be null. Should never contain a pair with both entries null. - void gatherPairedReqs( - llvm::SmallVectorImpl>&); - }]; - let assemblyFormat = [{ (`svc` $service_symbol^)? `impl` `as` $impl_type (`opts` $impl_opts^)? `(` $inputs `)` attr-dict `:` functional-type($inputs, results) @@ -137,7 +128,8 @@ def ServiceImplementReqOp : ESI_Op<"service.impl_req", [NoTerminator]> { } def RequestToServerConnectionOp : ESI_Op<"service.req.to_server", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ServiceReqOpInterface]> { let summary = "Request a connection to send data"; let arguments = (ins HWInnerRefAttr:$servicePort, @@ -146,10 +138,19 @@ def RequestToServerConnectionOp : ESI_Op<"service.req.to_server", [ $toServer `->` $servicePort `(` $clientNamePath `)` attr-dict `:` qualified(type($toServer)) }]; + + let extraClassDeclaration = [{ + // Set the value headed to the server. Overrides method in + // ServiceReqOpInterface. + void setToServer(Value v) { + getToServerMutable().assign(v); + } + }]; } def RequestToClientConnectionOp : ESI_Op<"service.req.to_client", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ServiceReqOpInterface]> { let summary = "Request a connection to receive data"; let arguments = (ins HWInnerRefAttr:$servicePort, @@ -162,7 +163,8 @@ def RequestToClientConnectionOp : ESI_Op<"service.req.to_client", [ } def RequestInOutChannelOp : ESI_Op<"service.req.inout", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ServiceReqOpInterface]> { let summary = "Request a bidirectional channel"; let arguments = (ins HWInnerRefAttr:$servicePort, @@ -174,6 +176,14 @@ def RequestInOutChannelOp : ESI_Op<"service.req.inout", [ $toServer `->` $servicePort `(` $clientNamePath `)` attr-dict `:` qualified(type($toServer)) `->` qualified(type($toClient)) }]; + + let extraClassDeclaration = [{ + // Set the value headed to the server. Overrides method in + // ServiceReqOpInterface. + void setToServer(Value v) { + getToServerMutable().assign(v); + } + }]; } def ServiceHierarchyMetadataOp : ESI_Op<"service.hierarchy.metadata", [ diff --git a/lib/Dialect/ESI/ESIOps.cpp b/lib/Dialect/ESI/ESIOps.cpp index 4df8d39a5b97..1a5f49d0dbb1 100644 --- a/lib/Dialect/ESI/ESIOps.cpp +++ b/lib/Dialect/ESI/ESIOps.cpp @@ -464,61 +464,6 @@ LogicalResult ServiceHierarchyMetadataOp::verifySymbolUses( return success(); } -void ServiceImplementReqOp::gatherPairedReqs( - llvm::SmallVectorImpl> &reqPairs) { - - // Build a mapping of client path names to requests. - DenseMap, - SmallVector> - clientNameToServer; - DenseMap, - SmallVector> - clientNameToClient; - for (auto &op : getOps()) - if (auto req = dyn_cast(op)) - clientNameToClient[std::make_pair(req.getServicePort(), - req.getClientNamePathAttr())] - .push_back(req); - else if (auto req = dyn_cast(op)) - clientNameToServer[std::make_pair(req.getServicePort(), - req.getClientNamePathAttr())] - .push_back(req); - - // Find all of the pairs and emit them. - DenseSet emittedOps; - for (auto op : getOps()) { - std::pair clientName = - std::make_pair(op.getServicePort(), op.getClientNamePathAttr()); - const SmallVector &ops = - clientNameToServer[clientName]; - - // Only emit a pair if there's one toServer and one toClient request for a - // given client name path. - if (ops.size() == 1) { - auto toClientF = clientNameToClient.find(clientName); - if (toClientF != clientNameToClient.end() && - toClientF->second.size() == 1) { - reqPairs.push_back( - std::make_pair(ops.front(), toClientF->second.front())); - emittedOps.insert(ops.front()); - emittedOps.insert(toClientF->second.front()); - continue; - } - } - } - - // Emit partial pairs for all the remaining requests. - for (auto &op : getOps()) { - if (emittedOps.contains(&op)) - continue; - if (auto req = dyn_cast(op)) - reqPairs.push_back(std::make_pair(nullptr, req)); - else if (auto req = dyn_cast(op)) - reqPairs.push_back(std::make_pair(req, nullptr)); - } -} - //===----------------------------------------------------------------------===// // Structural ops. //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/ESI/ESIServices.cpp b/lib/Dialect/ESI/ESIServices.cpp index c1633d948156..a6caa90a4176 100644 --- a/lib/Dialect/ESI/ESIServices.cpp +++ b/lib/Dialect/ESI/ESIServices.cpp @@ -44,18 +44,18 @@ ServiceGeneratorDispatcher::generate(ServiceImplementReqOp req, } /// The generator for the "cosim" impl_type. -static LogicalResult instantiateCosimEndpointOps(ServiceImplementReqOp req, +static LogicalResult instantiateCosimEndpointOps(ServiceImplementReqOp implReq, ServiceDeclOpInterface) { - auto *ctxt = req.getContext(); - OpBuilder b(req); - Value clk = req.getOperand(0); - Value rst = req.getOperand(1); + auto *ctxt = implReq.getContext(); + OpBuilder b(implReq); + Value clk = implReq.getOperand(0); + Value rst = implReq.getOperand(1); // Determine which EndpointID this generator should start with. - if (req.getImplOpts()) { - auto opts = req.getImplOpts()->getValue(); + if (implReq.getImplOpts()) { + auto opts = implReq.getImplOpts()->getValue(); for (auto nameAttr : opts) { - return req.emitOpError("did not recognize option name ") + return implReq.emitOpError("did not recognize option name ") << nameAttr.getName(); } } @@ -68,74 +68,63 @@ static LogicalResult instantiateCosimEndpointOps(ServiceImplementReqOp req, return StringAttr::get(ctxt, os.str()); }; - llvm::DenseMap toClientResultNum; - for (auto toClient : req.getOps()) - toClientResultNum[toClient] = toClientResultNum.size(); - - // Get the request pairs. - llvm::SmallVector< - std::pair, 8> - reqPairs; - req.gatherPairedReqs(reqPairs); + llvm::DenseMap toClientResultNum; + for (auto req : implReq.getOps()) + if (req.getToClient()) + toClientResultNum[req] = toClientResultNum.size(); // Iterate through them, building a cosim endpoint for each one. - for (auto [toServer, toClient] : reqPairs) { - assert((toServer || toClient) && - "At least one in all pairs must be non-null"); - Location loc = toServer ? toServer.getLoc() : toClient.getLoc(); - ArrayAttr clientNamePathAttr = toServer ? toServer.getClientNamePathAttr() - : toClient.getClientNamePathAttr(); - - Value toServerValue; - if (toServer) - toServerValue = toServer.getToServer(); - else + for (auto req : implReq.getOps()) { + Location loc = req->getLoc(); + ArrayAttr clientNamePathAttr = req.getClientNamePath(); + + Value toServerValue = req.getToServer(); + if (!toServerValue) toServerValue = b.create(loc, ChannelType::get(ctxt, b.getI1Type())); - Type toClientType; - if (toClient) - toClientType = toClient.getToClient().getType(); - else + Type toClientType = req.getToClientType(); + if (!toClientType) toClientType = ChannelType::get(ctxt, b.getI1Type()); auto cosim = b.create(loc, toClientType, clk, rst, toServerValue, toStringAttr(clientNamePathAttr)); - if (toClient) { - unsigned clientReqIdx = toClientResultNum[toClient]; - req.getResult(clientReqIdx).replaceAllUsesWith(cosim.getRecv()); + if (req.getToClient()) { + unsigned clientReqIdx = toClientResultNum[req]; + implReq.getResult(clientReqIdx).replaceAllUsesWith(cosim.getRecv()); } } // Erase the generation request. - req.erase(); + implReq.erase(); return success(); } // Generator for "sv_mem" implementation type. Emits SV ops for an unpacked // array, hopefully inferred as a memory to the SV compiler. static LogicalResult -instantiateSystemVerilogMemory(ServiceImplementReqOp req, +instantiateSystemVerilogMemory(ServiceImplementReqOp implReq, ServiceDeclOpInterface decl) { if (!decl) - return req.emitOpError( + return implReq.emitOpError( "Must specify a service declaration to use 'sv_mem'."); - ImplicitLocOpBuilder b(req.getLoc(), req); - BackedgeBuilder bb(b, req.getLoc()); + ImplicitLocOpBuilder b(implReq.getLoc(), implReq); + BackedgeBuilder bb(b, implReq.getLoc()); RandomAccessMemoryDeclOp ramDecl = dyn_cast(decl.getOperation()); if (!ramDecl) - return req.emitOpError("'sv_mem' implementation type can only be used to " - "implement RandomAccessMemory declarations"); - - if (req.getNumOperands() != 2) - return req.emitOpError("Implementation requires clk and rst operands"); - auto clk = req.getOperand(0); - auto rst = req.getOperand(1); + return implReq.emitOpError( + "'sv_mem' implementation type can only be used to " + "implement RandomAccessMemory declarations"); + + if (implReq.getNumOperands() != 2) + return implReq.emitOpError("Implementation requires clk and rst operands"); + auto clk = implReq.getOperand(0); + auto rst = implReq.getOperand(1); auto write = b.getStringAttr("write"); auto read = b.getStringAttr("read"); auto none = b.create( @@ -143,45 +132,47 @@ instantiateSystemVerilogMemory(ServiceImplementReqOp req, auto i1 = b.getI1Type(); auto c0 = b.create(i1, 0); + // List of reqs which have a result. + SmallVector toClientReqs(llvm::make_filter_range( + implReq.getOps(), + [](auto req) { return req.getToClient() != nullptr; })); + // Assemble a mapping of toClient results to actual consumers. DenseMap outputMap; - for (auto [bout, reqout] : llvm::zip_longest( - req.getOps(), req.getResults())) { + for (auto [bout, reqout] : + llvm::zip_longest(toClientReqs, implReq.getResults())) { assert(bout.has_value()); assert(reqout.has_value()); - outputMap[*bout] = *reqout; + Value toClient = bout->getToClient(); + outputMap[toClient] = *reqout; } // Create the SV memory. hw::UnpackedArrayType memType = hw::UnpackedArrayType::get(ramDecl.getInnerType(), ramDecl.getDepth()); - auto mem = b.create(memType, req.getServiceSymbolAttr().getAttr()) - .getResult(); - - // Get the request pairs. - llvm::SmallVector< - std::pair, 8> - reqPairs; - req.gatherPairedReqs(reqPairs); + auto mem = + b.create(memType, implReq.getServiceSymbolAttr().getAttr()) + .getResult(); // Do everything which doesn't actually write to the memory, store the signals // needed for the actual memory writes for later. SmallVector> writeGoAddressData; - for (auto [toServerReq, toClientReq] : reqPairs) { - assert(toServerReq && toClientReq); // All of our interfaces are inout. - assert(toServerReq.getServicePort() == toClientReq.getServicePort()); - auto port = toServerReq.getServicePort().getName(); + for (auto req : implReq.getOps()) { + auto port = req.getServicePort().getName(); WrapValidReadyOp toClientResp; if (port == write) { // If this pair is doing a write... + auto ioReq = dyn_cast(*req); + if (!ioReq) + return req->emitOpError("Memory write requests must be to/from server"); // Construct the response channel. auto doneValid = bb.get(i1); toClientResp = b.create(none, doneValid); // Unwrap the write request and 'explode' the struct. - auto unwrap = b.create(toServerReq.getToServer(), + auto unwrap = b.create(ioReq.getToServer(), toClientResp.getReady()); Value address = b.create(unwrap.getRawOutput(), @@ -200,6 +191,9 @@ instantiateSystemVerilogMemory(ServiceImplementReqOp req, } else if (port == read) { // If it's a read... + auto ioReq = dyn_cast(*req); + if (!ioReq) + return req->emitOpError("Memory read requests must be to/from server"); // Construct the response channel. auto dataValid = bb.get(i1); @@ -208,7 +202,7 @@ instantiateSystemVerilogMemory(ServiceImplementReqOp req, // Unwrap the requested address and read from that memory location. auto addressUnwrap = b.create( - toServerReq.getToServer(), toClientResp.getReady()); + ioReq.getToServer(), toClientResp.getReady()); Value memLoc = b.create(mem, addressUnwrap.getRawOutput()); auto readData = b.create(memLoc); @@ -220,7 +214,7 @@ instantiateSystemVerilogMemory(ServiceImplementReqOp req, assert(false && "Port should be either 'read' or 'write'"); } - outputMap[toClientReq.getToClient()].replaceAllUsesWith( + outputMap[req.getToClient()].replaceAllUsesWith( toClientResp.getChanOutput()); } @@ -238,7 +232,7 @@ instantiateSystemVerilogMemory(ServiceImplementReqOp req, } }); - req.erase(); + implReq.erase(); return success(); } @@ -279,8 +273,7 @@ struct ESIConnectServicesPass /// module specified. Create and connect up ports to tunnel the ESI channels /// through. LogicalResult surfaceReqs(hw::HWMutableModuleLike, - ArrayRef, - ArrayRef); + ArrayRef); /// Copy all service metadata up the instance hierarchy. Modify the service /// name path while copying. @@ -338,36 +331,14 @@ LogicalResult ESIConnectServicesPass::process(hw::HWModuleLike mod) { anyServiceInst = b; } - // Decompose the 'inout' requests int to 'in' and 'out' requests. - mod.walk([&](RequestInOutChannelOp reqInOut) { - ImplicitLocOpBuilder b(reqInOut.getLoc(), reqInOut); - b.create(reqInOut.getServicePortAttr(), - reqInOut.getToServer(), - reqInOut.getClientNamePathAttr()); - auto toClientReq = b.create( - reqInOut.getToClient().getType(), reqInOut.getServicePortAttr(), - reqInOut.getClientNamePathAttr()); - reqInOut.getToClient().replaceAllUsesWith(toClientReq.getToClient()); - reqInOut.erase(); - }); - // Find all of the "local" requests. - mod.walk([&](Operation *op) { - if (auto req = dyn_cast(op)) { - auto service = req.getServicePortAttr().getModuleRef(); - auto implOpF = localImplReqs.find(service); - if (implOpF != localImplReqs.end()) - req->moveBefore(implOpF->second, implOpF->second->end()); - else if (anyServiceInst) - req->moveBefore(anyServiceInst, anyServiceInst->end()); - } else if (auto req = dyn_cast(op)) { - auto service = req.getServicePortAttr().getModuleRef(); - auto implOpF = localImplReqs.find(service); - if (implOpF != localImplReqs.end()) - req->moveBefore(implOpF->second, implOpF->second->end()); - else if (anyServiceInst) - req->moveBefore(anyServiceInst, anyServiceInst->end()); - } + mod.walk([&](ServiceReqOpInterface req) { + auto service = req.getServicePort().getModuleRef(); + auto implOpF = localImplReqs.find(service); + if (implOpF != localImplReqs.end()) + req->moveBefore(implOpF->second, implOpF->second->end()); + else if (anyServiceInst) + req->moveBefore(anyServiceInst, anyServiceInst->end()); }); // Replace each service instance with a generation request. If a service @@ -383,28 +354,20 @@ LogicalResult ESIConnectServicesPass::process(hw::HWModuleLike mod) { copyMetadata(mod); // Identify the non-local reqs which need to be surfaced from this module. - SmallVector nonLocalToClientReqs; - SmallVector nonLocalToServerReqs; - mod.walk([&](Operation *op) { - if (auto req = dyn_cast(op)) { - auto service = req.getServicePortAttr().getModuleRef(); - auto implOpF = localImplReqs.find(service); - if (implOpF == localImplReqs.end()) - nonLocalToClientReqs.push_back(req); - } else if (auto req = dyn_cast(op)) { - auto service = req.getServicePortAttr().getModuleRef(); - auto implOpF = localImplReqs.find(service); - if (implOpF == localImplReqs.end()) - nonLocalToServerReqs.push_back(req); - } + SmallVector nonLocalReqs; + mod.walk([&](ServiceReqOpInterface req) { + auto service = req.getServicePort().getModuleRef(); + auto implOpF = localImplReqs.find(service); + if (implOpF == localImplReqs.end()) + nonLocalReqs.push_back(req); }); // Surface all of the requests which cannot be fulfilled locally. - if (nonLocalToClientReqs.empty() && nonLocalToServerReqs.empty()) + if (nonLocalReqs.empty()) return success(); if (auto mutableMod = dyn_cast(mod.getOperation())) - return surfaceReqs(mutableMod, nonLocalToClientReqs, nonLocalToServerReqs); + return surfaceReqs(mutableMod, nonLocalReqs); return mod.emitOpError( "Cannot surface requests through module without mutable ports"); } @@ -442,27 +405,17 @@ static void emitServiceMetadata(ServiceImplementReqOp implReqOp) { b.setInsertionPointToStart(bspPorts.get()); } - llvm::SmallVector< - std::pair, 8> - reqPairs; - implReqOp.gatherPairedReqs(reqPairs); - SmallVector clients; - for (auto [toServer, toClient] : reqPairs) { + for (auto req : implReqOp.getOps()) { SmallVector clientAttrs; - Attribute servicePort, clientNamePath; - if (toServer) { - clientNamePath = toServer.getClientNamePathAttr(); - servicePort = toServer.getServicePortAttr(); + Attribute clientNamePath = req.getClientNamePath(); + Attribute servicePort = req.getServicePort(); + if (req.getToServerType()) clientAttrs.push_back(b.getNamedAttr( - "to_server_type", TypeAttr::get(toServer.getToServer().getType()))); - } - if (toClient) { - clientNamePath = toClient.getClientNamePathAttr(); - servicePort = toClient.getServicePortAttr(); + "to_server_type", TypeAttr::get(req.getToServerType()))); + if (req.getToClient()) clientAttrs.push_back(b.getNamedAttr( - "to_client_type", TypeAttr::get(toClient.getToClient().getType()))); - } + "to_client_type", TypeAttr::get(req.getToClientType()))); clientAttrs.push_back(b.getNamedAttr("port", servicePort)); clientAttrs.push_back(b.getNamedAttr("client_name", clientNamePath)); @@ -472,17 +425,23 @@ static void emitServiceMetadata(ServiceImplementReqOp implReqOp) { if (!bspPorts) continue; - if (toServer && toClient) - b.create( - toServer.getServicePort().getName(), - TypeAttr::get(toServer.getToServer().getType()), - TypeAttr::get(toClient.getToClient().getType())); - else if (toClient) - b.create(toClient.getServicePort().getName(), - TypeAttr::get(toClient.getToClient().getType())); - else - b.create(toServer.getServicePort().getName(), - TypeAttr::get(toServer.getToServer().getType())); + llvm::TypeSwitch(req) + .Case([&](RequestInOutChannelOp) { + assert(req.getToClientType()); + assert(req.getToServerType()); + b.create(req.getServicePort().getName(), + TypeAttr::get(req.getToServerType()), + TypeAttr::get(req.getToClientType())); + }) + .Case([&](RequestToClientConnectionOp) { + b.create(req.getServicePort().getName(), + TypeAttr::get(req.getToClientType())); + }) + .Case([&](RequestToServerConnectionOp) { + b.create(req.getServicePort().getName(), + TypeAttr::get(req.getToServerType())); + }) + .Default([](Operation *) {}); } if (bspPorts && !bspPorts->empty()) { @@ -519,8 +478,9 @@ LogicalResult ESIConnectServicesPass::replaceInst(ServiceInstanceOp instOp, // + the to_client types. SmallVector resultTypes(instOp.getResultTypes().begin(), instOp.getResultTypes().end()); - for (auto toClient : portReqs->getOps()) - resultTypes.push_back(toClient.getToClient().getType()); + for (auto req : portReqs->getOps()) + if (auto t = req.getToClientType()) + resultTypes.push_back(t); // Create the generation request op. OpBuilder b(instOp); @@ -534,10 +494,14 @@ LogicalResult ESIConnectServicesPass::replaceInst(ServiceInstanceOp instOp, for (auto [n, o] : llvm::zip(implOp.getResults(), instOp.getResults())) o.replaceAllUsesWith(n); unsigned instOpNumResults = instOp.getNumResults(); - for (auto e : - llvm::enumerate(portReqs->getOps())) - e.value().getToClient().replaceAllUsesWith( - implOp.getResult(e.index() + instOpNumResults)); + for (auto [idx, req] : llvm::enumerate( + llvm::make_filter_range(portReqs->getOps(), + [](ServiceReqOpInterface req) -> bool { + return req.getToClient() != nullptr; + }))) { + req.getToClient().replaceAllUsesWith( + implOp.getResult(idx + instOpNumResults)); + } emitServiceMetadata(implOp); @@ -549,10 +513,9 @@ LogicalResult ESIConnectServicesPass::replaceInst(ServiceInstanceOp instOp, return success(); } -LogicalResult ESIConnectServicesPass::surfaceReqs( - hw::HWMutableModuleLike mod, - ArrayRef toClientReqs, - ArrayRef toServerReqs) { +LogicalResult +ESIConnectServicesPass::surfaceReqs(hw::HWMutableModuleLike mod, + ArrayRef reqs) { auto *ctxt = mod.getContext(); Block *body = &mod->getRegion(0).front(); @@ -574,31 +537,30 @@ LogicalResult ESIConnectServicesPass::surfaceReqs( }; // Insert new module input ESI ports. - for (auto toClient : toClientReqs) { + for (auto req : reqs) { + Type toClientType = req.getToClientType(); + if (!toClientType) + continue; newInputs.push_back(std::make_pair( - origNumInputs, - hw::PortInfo{getPortName(toClient.getClientNamePath()), - hw::PortDirection::INPUT, toClient.getType(), - origNumInputs, nullptr, toClient.getLoc()})); - body->addArgument(toClient.getType(), toClient.getLoc()); - } - mod.insertPorts(newInputs, {}); + origNumInputs, hw::PortInfo{getPortName(req.getClientNamePath()), + hw::PortDirection::INPUT, toClientType, + origNumInputs, nullptr, req->getLoc()})); - // Replace uses with new block args which will correspond to said ports. - // Note: no zip or enumerate here because we need mutable access to - // toClientReqs. - int i = 0; - for (auto toClient : toClientReqs) { - toClient.replaceAllUsesWith(body->getArguments()[origNumInputs + i]); - ++i; + // Replace uses with new block args which will correspond to said ports. + Value replValue = body->addArgument(toClientType, req->getLoc()); + req.getToClient().replaceAllUsesWith(replValue); } + mod.insertPorts(newInputs, {}); // Append output ports to new port list and redirect toServer inputs to // output op. unsigned outputCounter = origNumOutputs; - for (auto toServer : toServerReqs) - newOutputs.push_back( - {getPortName(toServer.getClientNamePath()), toServer.getToServer()}); + for (auto req : reqs) { + Value toServer = req.getToServer(); + if (!toServer) + continue; + newOutputs.push_back({getPortName(req.getClientNamePath()), toServer}); + } mod.appendOutputs(newOutputs); @@ -626,17 +588,22 @@ LogicalResult ESIConnectServicesPass::surfaceReqs( // Add new inputs for the new to_client requests and clone the request // into the module containing `inst`. - for (auto [toClient, newPort] : llvm::zip(toClientReqs, newInputs)) { - auto instToClient = cast(b.clone(*toClient)); - instToClient.setClientNamePathAttr(prependNamePart( - instToClient.getClientNamePath(), inst.getInstanceName())); - newOperands.push_back(instToClient.getToClient()); + circt::BackedgeBuilder beb(b, mod.getLoc()); + SmallVector newResultBackedges; + for (auto req : reqs) { + auto clone = cast(b.clone(*req)); + clone.setClientNamePath( + prependNamePart(clone.getClientNamePath(), inst.getInstanceName())); + if (Value toClient = clone.getToClient()) + newOperands.push_back(toClient); + if (Type toServerType = clone.getToServerType()) { + newResultTypes.push_back(toServerType); + Backedge result = beb.get(toServerType); + newResultBackedges.push_back(result); + clone.setToServer(result); + } } - // Append the results for the to_server requests. - for (auto newPort : newOutputs) - newResultTypes.push_back(newPort.second.getType()); - // Create a replacement instance of the same operation type. SmallVector newAttrs; for (auto attr : inst->getAttrs()) { @@ -648,25 +615,21 @@ LogicalResult ESIConnectServicesPass::surfaceReqs( else newAttrs.push_back(attr); } - auto *newHWInst = b.insert(Operation::create( + auto *newInst = b.insert(Operation::create( inst->getLoc(), inst->getName(), newResultTypes, newOperands, b.getDictionaryAttr(newAttrs), inst->getPropertiesStorage(), inst->getSuccessors(), inst->getRegions())); - newModuleInstantiations.push_back(cast(newHWInst)); + newModuleInstantiations.push_back(cast(newInst)); // Replace all uses of the instance being replaced. for (auto [newV, oldV] : - llvm::zip(newHWInst->getResults(), inst->getResults())) + llvm::zip(newInst->getResults(), inst->getResults())) oldV.replaceAllUsesWith(newV); // Clone the to_server requests and wire them up to the new instance. outputCounter = origNumOutputs; - for (auto [toServer, newPort] : llvm::zip(toServerReqs, newOutputs)) { - auto instToServer = cast(b.clone(*toServer)); - instToServer.setClientNamePathAttr(prependNamePart( - instToServer.getClientNamePath(), inst.getInstanceName())); - instToServer->setOperand(0, newHWInst->getResult(outputCounter++)); - } + for (Backedge newResult : newResultBackedges) + newResult.setValue(newInst->getResult(outputCounter++)); } // Replace the list of instantiations and erase the old ones. @@ -676,10 +639,8 @@ LogicalResult ESIConnectServicesPass::surfaceReqs( // Erase the original requests since they have been cloned into the proper // destination modules. - for (auto toClient : toClientReqs) - toClient.erase(); - for (auto toServer : toServerReqs) - toServer.erase(); + for (auto req : reqs) + req.erase(); return success(); } diff --git a/test/Dialect/ESI/errors.mlir b/test/Dialect/ESI/errors.mlir index a895c99366f1..fe2aeafd1813 100644 --- a/test/Dialect/ESI/errors.mlir +++ b/test/Dialect/ESI/errors.mlir @@ -79,6 +79,27 @@ hw.module @Loopback (%clk: i1) -> () { // ----- +esi.mem.ram @MemA i64 x 20 +!write = !hw.struct +hw.module @MemoryAccess1(%clk: i1, %rst: i1, %write: !esi.channel) -> () { + // expected-error @+1 {{'esi.service.instance' op failed to generate server}} + esi.service.instance svc @MemA impl as "sv_mem" (%clk, %rst) : (i1, i1) -> () + // expected-error @+1 {{'esi.service.req.to_server' op Memory write requests must be to/from server}} + esi.service.req.to_server %write -> <@MemA::@write> ([]) : !esi.channel +} + +// ----- + +esi.mem.ram @MemA i64 x 20 +hw.module @MemoryAccess1(%clk: i1, %rst: i1, %addr: !esi.channel) -> () { + // expected-error @+1 {{'esi.service.instance' op failed to generate server}} + esi.service.instance svc @MemA impl as "sv_mem" (%clk, %rst) : (i1, i1) -> () + // expected-error @+1 {{'esi.service.req.to_server' op Memory read requests must be to/from server}} + esi.service.req.to_server %addr -> <@MemA::@read> ([]) : !esi.channel +} + +// ----- + esi.service.decl @HostComms { esi.service.inout @ReqResp : !esi.channel -> !esi.channel } diff --git a/test/Dialect/ESI/services.mlir b/test/Dialect/ESI/services.mlir index 3a469e18501c..222361553d4a 100644 --- a/test/Dialect/ESI/services.mlir +++ b/test/Dialect/ESI/services.mlir @@ -37,11 +37,11 @@ hw.module @Loopback (%clk: i1) -> () { // CONN-LABEL: hw.module @Top2(%clk: i1) -> (chksum: i8) { // CONN: [[r0:%.+]]:3 = esi.service.impl_req svc @HostComms impl as "topComms2"(%clk) : (i1) -> (i8, !esi.channel, !esi.channel) { -// CONN: %1 = esi.service.req.to_client <@HostComms::@Recv>(["r1", "m1", "loopback_tohw"]) : !esi.channel -// CONN: %2 = esi.service.req.to_client <@HostComms::@Recv>(["r1", "c1", "consumingFromChan"]) : !esi.channel -// CONN: esi.service.req.to_server %r1.m1.loopback_fromhw -> <@HostComms::@Send>(["r1", "m1", "loopback_fromhw"]) : !esi.channel -// CONN: esi.service.req.to_server %r1.p1.producedMsgChan -> <@HostComms::@Send>(["r1", "p1", "producedMsgChan"]) : !esi.channel -// CONN: esi.service.req.to_server %r1.p2.producedMsgChan -> <@HostComms::@Send>(["r1", "p2", "producedMsgChan"]) : !esi.channel +// CONN-DAG: esi.service.req.to_client <@HostComms::@Recv>(["r1", "m1", "loopback_tohw"]) : !esi.channel +// CONN-DAG: esi.service.req.to_client <@HostComms::@Recv>(["r1", "c1", "consumingFromChan"]) : !esi.channel +// CONN-DAG: esi.service.req.to_server %r1.m1.loopback_fromhw -> <@HostComms::@Send>(["r1", "m1", "loopback_fromhw"]) : !esi.channel +// CONN-DAG: esi.service.req.to_server %r1.p1.producedMsgChan -> <@HostComms::@Send>(["r1", "p1", "producedMsgChan"]) : !esi.channel +// CONN-DAG: esi.service.req.to_server %r1.p2.producedMsgChan -> <@HostComms::@Send>(["r1", "p2", "producedMsgChan"]) : !esi.channel // CONN: } // CONN: %r1.m1.loopback_fromhw, %r1.p1.producedMsgChan, %r1.p2.producedMsgChan = hw.instance "r1" @Rec(clk: %clk: i1, m1.loopback_tohw: [[r0]]#1: !esi.channel, c1.consumingFromChan: [[r0]]#2: !esi.channel) -> (m1.loopback_fromhw: !esi.channel, p1.producedMsgChan: !esi.channel, p2.producedMsgChan: !esi.channel) // CONN: hw.output [[r0]]#0 : i8 @@ -112,8 +112,7 @@ msft.module @MsLoopback {} (%clk: i1) -> () { // CONN-LABEL: msft.module @InOutTop {} (%clk: i1) -> (chksum: i8) { // CONN: %0:2 = esi.service.impl_req svc @HostComms impl as "topComms"(%clk) : (i1) -> (i8, !esi.channel) { -// CONN: %1 = esi.service.req.to_client <@HostComms::@ReqResp>(["m1", "loopback_inout"]) : !esi.channel -// CONN: esi.service.req.to_server %m1.loopback_inout -> <@HostComms::@ReqResp>(["m1", "loopback_inout"]) : !esi.channel +// CONN: esi.service.req.inout %m1.loopback_inout -> <@HostComms::@ReqResp>(["m1", "loopback_inout"]) : !esi.channel -> !esi.channel // CONN: } // CONN: %m1.loopback_inout = msft.instance @m1 @InOutLoopback(%clk, %0#1) : (i1, !esi.channel) -> !esi.channel // CONN: msft.output %0#0 : i8 @@ -195,3 +194,16 @@ hw.module @MemoryAccess1(%clk: i1, %rst: i1, %write: !esi.channel, %read %readData = esi.service.req.inout %readAddress -> <@MemA::@read> ([]) : !esi.channel -> !esi.channel hw.output %readData, %done : !esi.channel, !esi.channel } + +// CONN-LABEL: hw.module @MemoryAccess2Read(%clk: i1, %rst: i1, %write: !esi.channel>, %readAddress: !esi.channel, %readAddress2: !esi.channel) -> (readData: !esi.channel, readData2: !esi.channel, writeDone: !esi.channel) { +// CONN: %MemA = sv.reg : !hw.inout> +// CONN: esi.service.hierarchy.metadata path [] implementing @MemA impl as "sv_mem" clients [{client_name = [], port = #hw.innerNameRef<@MemA::@write>, to_client_type = !esi.channel, to_server_type = !esi.channel>}, {client_name = [], port = #hw.innerNameRef<@MemA::@read>, to_client_type = !esi.channel, to_server_type = !esi.channel}, {client_name = [], port = #hw.innerNameRef<@MemA::@read>, to_client_type = !esi.channel, to_server_type = !esi.channel}] +// CONN: hw.output %chanOutput_0, %chanOutput_4, %chanOutput : !esi.channel, !esi.channel, !esi.channel + +hw.module @MemoryAccess2Read(%clk: i1, %rst: i1, %write: !esi.channel, %readAddress: !esi.channel, %readAddress2: !esi.channel) -> (readData: !esi.channel, readData2: !esi.channel, writeDone: !esi.channel) { + esi.service.instance svc @MemA impl as "sv_mem" (%clk, %rst) : (i1, i1) -> () + %done = esi.service.req.inout %write -> <@MemA::@write> ([]) : !esi.channel -> !esi.channel + %readData = esi.service.req.inout %readAddress -> <@MemA::@read> ([]) : !esi.channel -> !esi.channel + %readData2 = esi.service.req.inout %readAddress2 -> <@MemA::@read> ([]) : !esi.channel -> !esi.channel + hw.output %readData, %readData2, %done : !esi.channel, !esi.channel, !esi.channel +}