diff --git a/integration_test/Dialect/ESI/runtime/loopback.mlir.py b/integration_test/Dialect/ESI/runtime/loopback.mlir.py index 8bf332af4473..a942a45b232b 100644 --- a/integration_test/Dialect/ESI/runtime/loopback.mlir.py +++ b/integration_test/Dialect/ESI/runtime/loopback.mlir.py @@ -94,4 +94,7 @@ print(f"result: {result}") if platform != "trace": assert result == [-21, -22] + +acc = None + print("PASS") diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h index c3a1cd9da9e6..365b14f85404 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h @@ -74,11 +74,11 @@ class Accelerator : public HWModule { /// Abstract class representing a connection to an accelerator. Actual /// connections (e.g. to a co-simulation or actual device) are implemented by -/// subclasses. +/// subclasses. No methods in here are thread safe. class AcceleratorConnection { public: AcceleratorConnection(Context &ctxt); - virtual ~AcceleratorConnection() = default; + virtual ~AcceleratorConnection(); Context &getCtxt() const { return ctxt; } /// Disconnect from the accelerator cleanly. @@ -89,7 +89,12 @@ class AcceleratorConnection { virtual std::map requestChannelsFor(AppIDPath, const BundleType *) = 0; - AcceleratorServiceThread *getServiceThread() { return serviceThread.get(); } + /// Return a pointer to the accelerator 'service' thread (or threads). If the + /// thread(s) are not running, they will be started when this method is + /// called. `std::thread` is used. If users don't want the runtime to spin up + /// threads, don't call this method. `AcceleratorServiceThread` is owned by + /// AcceleratorConnection and governed by the lifetime of the this object. + AcceleratorServiceThread *getServiceThread(); using Service = services::Service; /// Get a typed reference to a particular service type. Caller does *not* take @@ -109,6 +114,10 @@ class AcceleratorConnection { ServiceImplDetails details = {}, HWClientDetails clients = {}); + /// Assume ownership of an accelerator object. Ties the lifetime of the + /// accelerator to this connection. Returns a raw pointer to the object. + Accelerator *takeOwnership(std::unique_ptr accel); + protected: /// Called by `getServiceImpl` exclusively. It wraps the pointer returned by /// this in a unique_ptr and caches it. Separate this from the @@ -128,6 +137,10 @@ class AcceleratorConnection { std::map> serviceCache; std::unique_ptr serviceThread; + + /// List of accelerator objects owned by this connection. These are destroyed + /// when the connection dies or is shutdown. + std::vector> ownedAccelerators; }; namespace registry { @@ -173,6 +186,9 @@ class AcceleratorServiceThread { addListener(std::initializer_list listenPorts, std::function callback); + /// Poll this module. + void addPoll(HWModule &module); + /// Instruct the service thread to stop running. void stop(); diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h index 6d1713ce82d4..b95421e0cd01 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h @@ -77,6 +77,11 @@ class HWModule { return portIndex; } + /// Master poll method. Calls the `poll` method on all locally owned ports and + /// the master `poll` method on all of the children. Returns true if any of + /// the `poll` calls returns true. + bool poll(); + protected: const std::optional info; const std::vector> children; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h index 547ecb9dbe94..fccb6272ba13 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h @@ -48,9 +48,10 @@ class Manifest { // Modules which have designer specified metadata. std::vector getModuleInfos() const; - // Build a dynamic design hierarchy from the manifest. - std::unique_ptr - buildAccelerator(AcceleratorConnection &acc) const; + // Build a dynamic design hierarchy from the manifest. The + // AcceleratorConnection owns the returned pointer so its lifetime is + // determined by the connection. + Accelerator *buildAccelerator(AcceleratorConnection &acc) const; /// The Type Table is an ordered list of types. The offset can be used to /// compactly and uniquely within a design. It does not include all of the diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h index 2db68e76e5f9..c025cbfa085c 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h @@ -33,21 +33,38 @@ namespace esi { class ChannelPort { public: ChannelPort(const Type *type) : type(type) {} - virtual ~ChannelPort() { disconnect(); } + virtual ~ChannelPort() {} /// Set up a connection to the accelerator. The buffer size is optional and /// should be considered merely a hint. Individual implementations use it /// however they like. The unit is number of messages of the port type. - virtual void connect(std::optional bufferSize = std::nullopt) { - connectImpl(bufferSize); + virtual void connect(std::optional bufferSize = std::nullopt) = 0; + virtual void disconnect() = 0; + virtual bool isConnected() const = 0; + + /// Poll for incoming data. Returns true if data was read or written into a + /// buffer as a result of the poll. Calling the call back could (will) also + /// happen in that case. Some backends need this to be called periodically. In + /// the usual case, this will be called by a background thread, but the ESI + /// runtime does not want to assume that the host processes use standard + /// threads. If the user wants to provide their own threads, they need to call + /// this on each port occasionally. This is also called from the 'master' poll + /// method in the Accelerator class. + bool poll() { + if (isConnected()) + return pollImpl(); + return false; } - virtual void disconnect() {} const Type *getType() const { return type; } -private: +protected: const Type *type; + /// Method called by poll() to actually poll the channel if the channel is + /// connected. + virtual bool pollImpl() { return false; } + /// Called by all connect methods to let backends initiate the underlying /// connections. virtual void connectImpl(std::optional bufferSize) {} @@ -58,8 +75,19 @@ class WriteChannelPort : public ChannelPort { public: using ChannelPort::ChannelPort; + virtual void + connect(std::optional bufferSize = std::nullopt) override { + connectImpl(bufferSize); + connected = true; + } + virtual void disconnect() override { connected = false; } + virtual bool isConnected() const override { return connected; } + /// A very basic write API. Will likely change for performance reasons. virtual void write(const MessageData &) = 0; + +private: + volatile bool connected = false; }; /// A ChannelPort which reads data from the accelerator. It has two modes: @@ -72,6 +100,9 @@ class ReadChannelPort : public ChannelPort { ReadChannelPort(const Type *type) : ChannelPort(type), mode(Mode::Disconnected) {} virtual void disconnect() override { mode = Mode::Disconnected; } + virtual bool isConnected() const override { + return mode != Mode::Disconnected; + } //===--------------------------------------------------------------------===// // Callback mode: To use a callback, connect with a callback function which @@ -121,7 +152,7 @@ class ReadChannelPort : public ChannelPort { protected: /// Indicates the current mode of the channel. enum Mode { Disconnected, Callback, Polling }; - Mode mode; + volatile Mode mode; /// Backends call this callback when new data is available. std::function callback; @@ -178,6 +209,15 @@ class BundlePort { return const_cast(dynamic_cast(this)); } + /// Calls `poll` on all channels in the bundle and returns true if any of them + /// returned true. + bool poll() { + bool result = false; + for (auto &channel : channels) + result |= channel.second.poll(); + return result; + } + private: AppID id; std::map channels; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h index 6be52ffe6c85..ff5416bb4398 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h @@ -52,6 +52,7 @@ class TraceAccelerator : public esi::AcceleratorConnection { /// is opened for writing. For 'Read' mode, this file is opened for reading. TraceAccelerator(Context &, Mode mode, std::filesystem::path manifestJson, std::filesystem::path traceFile); + ~TraceAccelerator() override; /// Parse the connection string and instantiate the accelerator. Format is: /// ":[:]". diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h index 04454cb3c8a4..def6b4ebf8bd 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h @@ -33,6 +33,7 @@ class XrtAccelerator : public esi::AcceleratorConnection { struct Impl; XrtAccelerator(Context &, std::string xclbin, std::string kernelName); + ~XrtAccelerator(); static std::unique_ptr connect(Context &, std::string connectionString); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp index 0aae47341b6e..9f94d784489d 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp @@ -34,7 +34,14 @@ using namespace esi::services; namespace esi { AcceleratorConnection::AcceleratorConnection(Context &ctxt) - : ctxt(ctxt), serviceThread(std::make_unique()) {} + : ctxt(ctxt), serviceThread(nullptr) {} +AcceleratorConnection::~AcceleratorConnection() { disconnect(); } + +AcceleratorServiceThread *AcceleratorConnection::getServiceThread() { + if (!serviceThread) + serviceThread = std::make_unique(); + return serviceThread.get(); +} services::Service *AcceleratorConnection::getService(Service::Type svcType, AppIDPath id, @@ -54,6 +61,13 @@ services::Service *AcceleratorConnection::getService(Service::Type svcType, return cacheEntry.get(); } +Accelerator * +AcceleratorConnection::takeOwnership(std::unique_ptr acc) { + Accelerator *ret = acc.get(); + ownedAccelerators.push_back(std::move(acc)); + return ret; +} + /// Get the path to the currently running executable. static std::filesystem::path getExePath() { #ifdef __linux__ @@ -224,18 +238,27 @@ struct AcceleratorServiceThread::Impl { addListener(std::initializer_list listenPorts, std::function callback); + void addTask(std::function task) { + std::lock_guard g(m); + taskList.push_back(task); + } + private: void loop(); volatile bool shutdown = false; std::thread me; - // Protect the listeners std::map. - std::mutex listenerMutex; + // Protect the shared data structures. + std::mutex m; + // Map of read ports to callbacks. std::map, std::future>> listeners; + + /// Tasks which should be called on every loop iteration. + std::vector> taskList; }; void AcceleratorServiceThread::Impl::loop() { @@ -245,6 +268,7 @@ void AcceleratorServiceThread::Impl::loop() { std::function, MessageData>> portUnlockWorkList; + std::vector> taskListCopy; MessageData data; while (!shutdown) { @@ -256,7 +280,7 @@ void AcceleratorServiceThread::Impl::loop() { // Check and gather data from all the read ports we are monitoring. Put the // callbacks to be called later so we can release the lock. { - std::lock_guard g(listenerMutex); + std::lock_guard g(m); for (auto &[channel, cbfPair] : listeners) { assert(channel && "Null channel in listener list"); std::future &f = cbfPair.second; @@ -273,13 +297,22 @@ void AcceleratorServiceThread::Impl::loop() { // Clear the worklist for the next iteration. portUnlockWorkList.clear(); + + // Call any tasks that have been added. Copy it first so we can release the + // lock ASAP. + { + std::lock_guard g(m); + taskListCopy = taskList; + } + for (auto &task : taskListCopy) + task(); } } void AcceleratorServiceThread::Impl::addListener( std::initializer_list listenPorts, std::function callback) { - std::lock_guard g(listenerMutex); + std::lock_guard g(m); for (auto port : listenPorts) { if (listeners.count(port)) throw std::runtime_error("Port already has a listener"); @@ -312,6 +345,11 @@ void AcceleratorServiceThread::addListener( impl->addListener(listenPorts, callback); } +void AcceleratorServiceThread::addPoll(HWModule &module) { + assert(impl && "Service thread not running"); + impl->addTask([&module]() { module.poll(); }); +} + void AcceleratorConnection::disconnect() { if (serviceThread) { serviceThread->stop(); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp index 8e5bd5e45f0d..9e54e694e449 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp @@ -47,4 +47,13 @@ HWModule::HWModule(std::optional info, childIndex(buildIndex(this->children)), services(services), ports(std::move(ports)), portIndex(buildIndex(this->ports)) {} +bool HWModule::poll() { + bool result = false; + for (auto &port : ports) + result |= port->poll(); + for (auto &child : children) + result |= child->poll(); + return result; +} + } // namespace esi diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 7e4fd84c5034..62d16feec419 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -534,9 +534,8 @@ std::vector Manifest::getModuleInfos() const { return ret; } -std::unique_ptr -Manifest::buildAccelerator(AcceleratorConnection &acc) const { - return impl->buildAccelerator(acc); +Accelerator *Manifest::buildAccelerator(AcceleratorConnection &acc) const { + return acc.takeOwnership(impl->buildAccelerator(acc)); } const std::vector &Manifest::getTypeTable() const { diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp index 29f858defa16..cd5fd6596811 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp @@ -47,7 +47,7 @@ void ReadChannelPort::connect(std::function callback, throw std::runtime_error("Channel already connected"); mode = Mode::Callback; this->callback = callback; - ChannelPort::connect(bufferSize); + connectImpl(bufferSize); } void ReadChannelPort::connect(std::optional bufferSize) { @@ -71,7 +71,7 @@ void ReadChannelPort::connect(std::optional bufferSize) { } return true; }; - ChannelPort::connect(bufferSize); + connectImpl(bufferSize); } std::future ReadChannelPort::readAsync() { diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index 86a4685fd58b..42cc8bc25150 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -126,6 +126,7 @@ CosimAccelerator::CosimAccelerator(Context &ctxt, std::string hostname, rpcClient = new StubContainer(ChannelServer::NewStub(channel)); } CosimAccelerator::~CosimAccelerator() { + disconnect(); if (rpcClient) delete rpcClient; channels.clear(); @@ -418,23 +419,23 @@ class CosimHostMem : public HostMem { this->size = size; } virtual ~CosimHostMemRegion() { free(ptr); } - virtual void *getPtr() const { return ptr; } - virtual std::size_t getSize() const { return size; } + virtual void *getPtr() const override { return ptr; } + virtual std::size_t getSize() const override { return size; } private: void *ptr; std::size_t size; }; - virtual std::unique_ptr allocate(std::size_t size, - HostMem::Options opts) const { + virtual std::unique_ptr + allocate(std::size_t size, HostMem::Options opts) const override { return std::unique_ptr(new CosimHostMemRegion(size)); } virtual bool mapMemory(void *ptr, std::size_t size, - HostMem::Options opts) const { + HostMem::Options opts) const override { return true; } - virtual void unmapMemory(void *ptr) const {} + virtual void unmapMemory(void *ptr) const override {} }; } // namespace diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp index b93df860028d..a69f657b75df 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp @@ -135,6 +135,7 @@ TraceAccelerator::TraceAccelerator(Context &ctxt, Mode mode, : AcceleratorConnection(ctxt) { impl = std::make_unique(mode, manifestJson, traceFile); } +TraceAccelerator::~TraceAccelerator() { disconnect(); } Service *TraceAccelerator::createService(Service::Type svcType, AppIDPath idPath, std::string implName, @@ -197,22 +198,7 @@ class ReadTraceChannelPort : public ReadChannelPort { : ReadChannelPort(type) {} ~ReadTraceChannelPort() { disconnect(); } - void disconnect() override { - ReadChannelPort::disconnect(); - if (!dataPushThread.joinable()) - return; - shutdown = true; - shutdownCV.notify_all(); - dataPushThread.join(); - } - private: - void connectImpl(std::optional bufferSize) override { - assert(!dataPushThread.joinable() && "already connected"); - shutdown = false; - dataPushThread = std::thread(&ReadTraceChannelPort::dataPushLoop, this); - } - MessageData genMessage() { std::ptrdiff_t numBits = getType()->getBitWidth(); if (numBits < 0) @@ -227,19 +213,7 @@ class ReadTraceChannelPort : public ReadChannelPort { return MessageData(bytes); } - void dataPushLoop() { - std::mutex m; - std::unique_lock lock(m); - while (!shutdown) { - shutdownCV.wait_for(lock, std::chrono::milliseconds(100)); - while (this->callback(genMessage())) - shutdownCV.wait_for(lock, std::chrono::milliseconds(10)); - } - } - - std::thread dataPushThread; - std::condition_variable shutdownCV; - std::atomic shutdown; + bool pollImpl() override { return callback(genMessage()); } }; } // namespace @@ -289,8 +263,8 @@ class TraceHostMem : public HostMem { impl.write("HostMem") << "free " << ptr << std::endl; free(ptr); } - virtual void *getPtr() const { return ptr; } - virtual std::size_t getSize() const { return size; } + virtual void *getPtr() const override { return ptr; } + virtual std::size_t getSize() const override { return size; } private: void *ptr; @@ -298,8 +272,8 @@ class TraceHostMem : public HostMem { TraceAccelerator::Impl &impl; }; - virtual std::unique_ptr allocate(std::size_t size, - HostMem::Options opts) const { + virtual std::unique_ptr + allocate(std::size_t size, HostMem::Options opts) const override { auto ret = std::unique_ptr(new TraceHostMemRegion(size, impl)); impl.write("HostMem 0x") @@ -309,14 +283,14 @@ class TraceHostMem : public HostMem { return ret; } virtual bool mapMemory(void *ptr, std::size_t size, - HostMem::Options opts) const { + HostMem::Options opts) const override { impl.write("HostMem") << "map 0x" << ptr << " size " << size << " bytes. Writeable: " << opts.writeable << ", useLargePages: " << opts.useLargePages << std::endl; return true; } - virtual void unmapMemory(void *ptr) const { + virtual void unmapMemory(void *ptr) const override { impl.write("HostMem") << "unmap 0x" << ptr << std::endl; } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp index 694f069852c5..26b7ada78384 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp @@ -78,6 +78,7 @@ XrtAccelerator::XrtAccelerator(Context &ctxt, std::string xclbin, : AcceleratorConnection(ctxt) { impl = make_unique(xclbin, device_id); } +XrtAccelerator::~XrtAccelerator() { disconnect(); } namespace { class XrtMMIO : public MMIO { diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp index a093988cf1bd..450cabece943 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp @@ -121,10 +121,10 @@ void printInstance(std::ostream &os, const HWModule *d, void printHier(std::ostream &os, AcceleratorConnection &acc) { Manifest manifest(acc.getCtxt(), acc.getService()->getJsonManifest()); - std::unique_ptr design = manifest.buildAccelerator(acc); + Accelerator *design = manifest.buildAccelerator(acc); os << "********************************" << std::endl; os << "* Design hierarchy" << std::endl; os << "********************************" << std::endl; os << std::endl; - printInstance(os, design.get()); + printInstance(os, design); } diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index fa34d4cad4be..d2563616d645 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -50,9 +50,10 @@ int main(int argc, const char *argv[]) { std::unique_ptr acc = ctxt.connect(backend, conn); const auto &info = *acc->getService(); Manifest manifest(ctxt, info.getJsonManifest()); - std::unique_ptr accel = manifest.buildAccelerator(*acc); + Accelerator *accel = manifest.buildAccelerator(*acc); + acc->getServiceThread()->addPoll(*accel); - registerCallbacks(accel.get()); + registerCallbacks(accel); if (cmd == "loop") { while (true) { diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp index c73df6872ca5..08061b77e84a 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp +++ b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp @@ -270,8 +270,14 @@ PYBIND11_MODULE(esiCppAccel, m) { py::class_(m, "Manifest") .def(py::init()) .def_property_readonly("api_version", &Manifest::getApiVersion) - .def("build_accelerator", &Manifest::buildAccelerator, - py::return_value_policy::take_ownership) + .def( + "build_accelerator", + [&](Manifest &m, AcceleratorConnection &conn) { + auto acc = m.buildAccelerator(conn); + conn.getServiceThread()->addPoll(*acc); + return acc; + }, + py::return_value_policy::reference) .def_property_readonly("type_table", &Manifest::getTypeTable) .def_property_readonly("module_infos", &Manifest::getModuleInfos); }