From ad913784651c64f72306b335eedc4f7feaf71372 Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 8 Aug 2024 12:53:42 +0200 Subject: [PATCH] [ESI][Runtime] Poll method and optional service thread polling (#7460) Add a poll method to ports, a master poll method to the Accelerator, and the ability to poll from the service thread. Also, only spin up the service thread if it's requested. The service thread polling (in particular) required some ownership changes: Accelerator objects now belong to the AcceleratorConnection so that the ports aren't destructed before the service thread gets shutdown (which causes an invalid memory access). This particular binding isn't ideal, is brittle, and will be an issue for anything doing the polling. Resolving #7457 should mitigate this issue. Backends are now _required_ to call `disconnect` in their destructor. --- .../Dialect/ESI/runtime/loopback.mlir.py | 3 ++ .../ESI/runtime/cpp/include/esi/Accelerator.h | 22 ++++++-- .../ESI/runtime/cpp/include/esi/Design.h | 5 ++ .../ESI/runtime/cpp/include/esi/Manifest.h | 7 +-- .../ESI/runtime/cpp/include/esi/Ports.h | 52 ++++++++++++++++--- .../runtime/cpp/include/esi/backends/Trace.h | 1 + .../runtime/cpp/include/esi/backends/Xrt.h | 1 + .../ESI/runtime/cpp/lib/Accelerator.cpp | 48 +++++++++++++++-- lib/Dialect/ESI/runtime/cpp/lib/Design.cpp | 9 ++++ lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp | 5 +- lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp | 4 +- .../ESI/runtime/cpp/lib/backends/Cosim.cpp | 13 ++--- .../ESI/runtime/cpp/lib/backends/Trace.cpp | 42 +++------------ .../ESI/runtime/cpp/lib/backends/Xrt.cpp | 1 + .../ESI/runtime/cpp/tools/esiquery.cpp | 4 +- .../ESI/runtime/cpp/tools/esitester.cpp | 5 +- .../runtime/python/esiaccel/esiCppAccel.cpp | 10 +++- 17 files changed, 164 insertions(+), 68 deletions(-) diff --git a/integration_test/Dialect/ESI/runtime/loopback.mlir.py b/integration_test/Dialect/ESI/runtime/loopback.mlir.py index 8bf332af4473..a942a45b232b 100644 --- a/integration_test/Dialect/ESI/runtime/loopback.mlir.py +++ b/integration_test/Dialect/ESI/runtime/loopback.mlir.py @@ -94,4 +94,7 @@ print(f"result: {result}") if platform != "trace": assert result == [-21, -22] + +acc = None + print("PASS") diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h index c3a1cd9da9e6..365b14f85404 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h @@ -74,11 +74,11 @@ class Accelerator : public HWModule { /// Abstract class representing a connection to an accelerator. Actual /// connections (e.g. to a co-simulation or actual device) are implemented by -/// subclasses. +/// subclasses. No methods in here are thread safe. class AcceleratorConnection { public: AcceleratorConnection(Context &ctxt); - virtual ~AcceleratorConnection() = default; + virtual ~AcceleratorConnection(); Context &getCtxt() const { return ctxt; } /// Disconnect from the accelerator cleanly. @@ -89,7 +89,12 @@ class AcceleratorConnection { virtual std::map requestChannelsFor(AppIDPath, const BundleType *) = 0; - AcceleratorServiceThread *getServiceThread() { return serviceThread.get(); } + /// Return a pointer to the accelerator 'service' thread (or threads). If the + /// thread(s) are not running, they will be started when this method is + /// called. `std::thread` is used. If users don't want the runtime to spin up + /// threads, don't call this method. `AcceleratorServiceThread` is owned by + /// AcceleratorConnection and governed by the lifetime of the this object. + AcceleratorServiceThread *getServiceThread(); using Service = services::Service; /// Get a typed reference to a particular service type. Caller does *not* take @@ -109,6 +114,10 @@ class AcceleratorConnection { ServiceImplDetails details = {}, HWClientDetails clients = {}); + /// Assume ownership of an accelerator object. Ties the lifetime of the + /// accelerator to this connection. Returns a raw pointer to the object. + Accelerator *takeOwnership(std::unique_ptr accel); + protected: /// Called by `getServiceImpl` exclusively. It wraps the pointer returned by /// this in a unique_ptr and caches it. Separate this from the @@ -128,6 +137,10 @@ class AcceleratorConnection { std::map> serviceCache; std::unique_ptr serviceThread; + + /// List of accelerator objects owned by this connection. These are destroyed + /// when the connection dies or is shutdown. + std::vector> ownedAccelerators; }; namespace registry { @@ -173,6 +186,9 @@ class AcceleratorServiceThread { addListener(std::initializer_list listenPorts, std::function callback); + /// Poll this module. + void addPoll(HWModule &module); + /// Instruct the service thread to stop running. void stop(); diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h index 6d1713ce82d4..b95421e0cd01 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h @@ -77,6 +77,11 @@ class HWModule { return portIndex; } + /// Master poll method. Calls the `poll` method on all locally owned ports and + /// the master `poll` method on all of the children. Returns true if any of + /// the `poll` calls returns true. + bool poll(); + protected: const std::optional info; const std::vector> children; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h index 547ecb9dbe94..fccb6272ba13 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h @@ -48,9 +48,10 @@ class Manifest { // Modules which have designer specified metadata. std::vector getModuleInfos() const; - // Build a dynamic design hierarchy from the manifest. - std::unique_ptr - buildAccelerator(AcceleratorConnection &acc) const; + // Build a dynamic design hierarchy from the manifest. The + // AcceleratorConnection owns the returned pointer so its lifetime is + // determined by the connection. + Accelerator *buildAccelerator(AcceleratorConnection &acc) const; /// The Type Table is an ordered list of types. The offset can be used to /// compactly and uniquely within a design. It does not include all of the diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h index 2db68e76e5f9..c025cbfa085c 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h @@ -33,21 +33,38 @@ namespace esi { class ChannelPort { public: ChannelPort(const Type *type) : type(type) {} - virtual ~ChannelPort() { disconnect(); } + virtual ~ChannelPort() {} /// Set up a connection to the accelerator. The buffer size is optional and /// should be considered merely a hint. Individual implementations use it /// however they like. The unit is number of messages of the port type. - virtual void connect(std::optional bufferSize = std::nullopt) { - connectImpl(bufferSize); + virtual void connect(std::optional bufferSize = std::nullopt) = 0; + virtual void disconnect() = 0; + virtual bool isConnected() const = 0; + + /// Poll for incoming data. Returns true if data was read or written into a + /// buffer as a result of the poll. Calling the call back could (will) also + /// happen in that case. Some backends need this to be called periodically. In + /// the usual case, this will be called by a background thread, but the ESI + /// runtime does not want to assume that the host processes use standard + /// threads. If the user wants to provide their own threads, they need to call + /// this on each port occasionally. This is also called from the 'master' poll + /// method in the Accelerator class. + bool poll() { + if (isConnected()) + return pollImpl(); + return false; } - virtual void disconnect() {} const Type *getType() const { return type; } -private: +protected: const Type *type; + /// Method called by poll() to actually poll the channel if the channel is + /// connected. + virtual bool pollImpl() { return false; } + /// Called by all connect methods to let backends initiate the underlying /// connections. virtual void connectImpl(std::optional bufferSize) {} @@ -58,8 +75,19 @@ class WriteChannelPort : public ChannelPort { public: using ChannelPort::ChannelPort; + virtual void + connect(std::optional bufferSize = std::nullopt) override { + connectImpl(bufferSize); + connected = true; + } + virtual void disconnect() override { connected = false; } + virtual bool isConnected() const override { return connected; } + /// A very basic write API. Will likely change for performance reasons. virtual void write(const MessageData &) = 0; + +private: + volatile bool connected = false; }; /// A ChannelPort which reads data from the accelerator. It has two modes: @@ -72,6 +100,9 @@ class ReadChannelPort : public ChannelPort { ReadChannelPort(const Type *type) : ChannelPort(type), mode(Mode::Disconnected) {} virtual void disconnect() override { mode = Mode::Disconnected; } + virtual bool isConnected() const override { + return mode != Mode::Disconnected; + } //===--------------------------------------------------------------------===// // Callback mode: To use a callback, connect with a callback function which @@ -121,7 +152,7 @@ class ReadChannelPort : public ChannelPort { protected: /// Indicates the current mode of the channel. enum Mode { Disconnected, Callback, Polling }; - Mode mode; + volatile Mode mode; /// Backends call this callback when new data is available. std::function callback; @@ -178,6 +209,15 @@ class BundlePort { return const_cast(dynamic_cast(this)); } + /// Calls `poll` on all channels in the bundle and returns true if any of them + /// returned true. + bool poll() { + bool result = false; + for (auto &channel : channels) + result |= channel.second.poll(); + return result; + } + private: AppID id; std::map channels; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h index 6be52ffe6c85..ff5416bb4398 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h @@ -52,6 +52,7 @@ class TraceAccelerator : public esi::AcceleratorConnection { /// is opened for writing. For 'Read' mode, this file is opened for reading. TraceAccelerator(Context &, Mode mode, std::filesystem::path manifestJson, std::filesystem::path traceFile); + ~TraceAccelerator() override; /// Parse the connection string and instantiate the accelerator. Format is: /// ":[:]". diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h index 04454cb3c8a4..def6b4ebf8bd 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h @@ -33,6 +33,7 @@ class XrtAccelerator : public esi::AcceleratorConnection { struct Impl; XrtAccelerator(Context &, std::string xclbin, std::string kernelName); + ~XrtAccelerator(); static std::unique_ptr connect(Context &, std::string connectionString); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp index 0aae47341b6e..9f94d784489d 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp @@ -34,7 +34,14 @@ using namespace esi::services; namespace esi { AcceleratorConnection::AcceleratorConnection(Context &ctxt) - : ctxt(ctxt), serviceThread(std::make_unique()) {} + : ctxt(ctxt), serviceThread(nullptr) {} +AcceleratorConnection::~AcceleratorConnection() { disconnect(); } + +AcceleratorServiceThread *AcceleratorConnection::getServiceThread() { + if (!serviceThread) + serviceThread = std::make_unique(); + return serviceThread.get(); +} services::Service *AcceleratorConnection::getService(Service::Type svcType, AppIDPath id, @@ -54,6 +61,13 @@ services::Service *AcceleratorConnection::getService(Service::Type svcType, return cacheEntry.get(); } +Accelerator * +AcceleratorConnection::takeOwnership(std::unique_ptr acc) { + Accelerator *ret = acc.get(); + ownedAccelerators.push_back(std::move(acc)); + return ret; +} + /// Get the path to the currently running executable. static std::filesystem::path getExePath() { #ifdef __linux__ @@ -224,18 +238,27 @@ struct AcceleratorServiceThread::Impl { addListener(std::initializer_list listenPorts, std::function callback); + void addTask(std::function task) { + std::lock_guard g(m); + taskList.push_back(task); + } + private: void loop(); volatile bool shutdown = false; std::thread me; - // Protect the listeners std::map. - std::mutex listenerMutex; + // Protect the shared data structures. + std::mutex m; + // Map of read ports to callbacks. std::map, std::future>> listeners; + + /// Tasks which should be called on every loop iteration. + std::vector> taskList; }; void AcceleratorServiceThread::Impl::loop() { @@ -245,6 +268,7 @@ void AcceleratorServiceThread::Impl::loop() { std::function, MessageData>> portUnlockWorkList; + std::vector> taskListCopy; MessageData data; while (!shutdown) { @@ -256,7 +280,7 @@ void AcceleratorServiceThread::Impl::loop() { // Check and gather data from all the read ports we are monitoring. Put the // callbacks to be called later so we can release the lock. { - std::lock_guard g(listenerMutex); + std::lock_guard g(m); for (auto &[channel, cbfPair] : listeners) { assert(channel && "Null channel in listener list"); std::future &f = cbfPair.second; @@ -273,13 +297,22 @@ void AcceleratorServiceThread::Impl::loop() { // Clear the worklist for the next iteration. portUnlockWorkList.clear(); + + // Call any tasks that have been added. Copy it first so we can release the + // lock ASAP. + { + std::lock_guard g(m); + taskListCopy = taskList; + } + for (auto &task : taskListCopy) + task(); } } void AcceleratorServiceThread::Impl::addListener( std::initializer_list listenPorts, std::function callback) { - std::lock_guard g(listenerMutex); + std::lock_guard g(m); for (auto port : listenPorts) { if (listeners.count(port)) throw std::runtime_error("Port already has a listener"); @@ -312,6 +345,11 @@ void AcceleratorServiceThread::addListener( impl->addListener(listenPorts, callback); } +void AcceleratorServiceThread::addPoll(HWModule &module) { + assert(impl && "Service thread not running"); + impl->addTask([&module]() { module.poll(); }); +} + void AcceleratorConnection::disconnect() { if (serviceThread) { serviceThread->stop(); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp index 8e5bd5e45f0d..9e54e694e449 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Design.cpp @@ -47,4 +47,13 @@ HWModule::HWModule(std::optional info, childIndex(buildIndex(this->children)), services(services), ports(std::move(ports)), portIndex(buildIndex(this->ports)) {} +bool HWModule::poll() { + bool result = false; + for (auto &port : ports) + result |= port->poll(); + for (auto &child : children) + result |= child->poll(); + return result; +} + } // namespace esi diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 7e4fd84c5034..62d16feec419 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -534,9 +534,8 @@ std::vector Manifest::getModuleInfos() const { return ret; } -std::unique_ptr -Manifest::buildAccelerator(AcceleratorConnection &acc) const { - return impl->buildAccelerator(acc); +Accelerator *Manifest::buildAccelerator(AcceleratorConnection &acc) const { + return acc.takeOwnership(impl->buildAccelerator(acc)); } const std::vector &Manifest::getTypeTable() const { diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp index 29f858defa16..cd5fd6596811 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp @@ -47,7 +47,7 @@ void ReadChannelPort::connect(std::function callback, throw std::runtime_error("Channel already connected"); mode = Mode::Callback; this->callback = callback; - ChannelPort::connect(bufferSize); + connectImpl(bufferSize); } void ReadChannelPort::connect(std::optional bufferSize) { @@ -71,7 +71,7 @@ void ReadChannelPort::connect(std::optional bufferSize) { } return true; }; - ChannelPort::connect(bufferSize); + connectImpl(bufferSize); } std::future ReadChannelPort::readAsync() { diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index 86a4685fd58b..42cc8bc25150 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -126,6 +126,7 @@ CosimAccelerator::CosimAccelerator(Context &ctxt, std::string hostname, rpcClient = new StubContainer(ChannelServer::NewStub(channel)); } CosimAccelerator::~CosimAccelerator() { + disconnect(); if (rpcClient) delete rpcClient; channels.clear(); @@ -418,23 +419,23 @@ class CosimHostMem : public HostMem { this->size = size; } virtual ~CosimHostMemRegion() { free(ptr); } - virtual void *getPtr() const { return ptr; } - virtual std::size_t getSize() const { return size; } + virtual void *getPtr() const override { return ptr; } + virtual std::size_t getSize() const override { return size; } private: void *ptr; std::size_t size; }; - virtual std::unique_ptr allocate(std::size_t size, - HostMem::Options opts) const { + virtual std::unique_ptr + allocate(std::size_t size, HostMem::Options opts) const override { return std::unique_ptr(new CosimHostMemRegion(size)); } virtual bool mapMemory(void *ptr, std::size_t size, - HostMem::Options opts) const { + HostMem::Options opts) const override { return true; } - virtual void unmapMemory(void *ptr) const {} + virtual void unmapMemory(void *ptr) const override {} }; } // namespace diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp index b93df860028d..a69f657b75df 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp @@ -135,6 +135,7 @@ TraceAccelerator::TraceAccelerator(Context &ctxt, Mode mode, : AcceleratorConnection(ctxt) { impl = std::make_unique(mode, manifestJson, traceFile); } +TraceAccelerator::~TraceAccelerator() { disconnect(); } Service *TraceAccelerator::createService(Service::Type svcType, AppIDPath idPath, std::string implName, @@ -197,22 +198,7 @@ class ReadTraceChannelPort : public ReadChannelPort { : ReadChannelPort(type) {} ~ReadTraceChannelPort() { disconnect(); } - void disconnect() override { - ReadChannelPort::disconnect(); - if (!dataPushThread.joinable()) - return; - shutdown = true; - shutdownCV.notify_all(); - dataPushThread.join(); - } - private: - void connectImpl(std::optional bufferSize) override { - assert(!dataPushThread.joinable() && "already connected"); - shutdown = false; - dataPushThread = std::thread(&ReadTraceChannelPort::dataPushLoop, this); - } - MessageData genMessage() { std::ptrdiff_t numBits = getType()->getBitWidth(); if (numBits < 0) @@ -227,19 +213,7 @@ class ReadTraceChannelPort : public ReadChannelPort { return MessageData(bytes); } - void dataPushLoop() { - std::mutex m; - std::unique_lock lock(m); - while (!shutdown) { - shutdownCV.wait_for(lock, std::chrono::milliseconds(100)); - while (this->callback(genMessage())) - shutdownCV.wait_for(lock, std::chrono::milliseconds(10)); - } - } - - std::thread dataPushThread; - std::condition_variable shutdownCV; - std::atomic shutdown; + bool pollImpl() override { return callback(genMessage()); } }; } // namespace @@ -289,8 +263,8 @@ class TraceHostMem : public HostMem { impl.write("HostMem") << "free " << ptr << std::endl; free(ptr); } - virtual void *getPtr() const { return ptr; } - virtual std::size_t getSize() const { return size; } + virtual void *getPtr() const override { return ptr; } + virtual std::size_t getSize() const override { return size; } private: void *ptr; @@ -298,8 +272,8 @@ class TraceHostMem : public HostMem { TraceAccelerator::Impl &impl; }; - virtual std::unique_ptr allocate(std::size_t size, - HostMem::Options opts) const { + virtual std::unique_ptr + allocate(std::size_t size, HostMem::Options opts) const override { auto ret = std::unique_ptr(new TraceHostMemRegion(size, impl)); impl.write("HostMem 0x") @@ -309,14 +283,14 @@ class TraceHostMem : public HostMem { return ret; } virtual bool mapMemory(void *ptr, std::size_t size, - HostMem::Options opts) const { + HostMem::Options opts) const override { impl.write("HostMem") << "map 0x" << ptr << " size " << size << " bytes. Writeable: " << opts.writeable << ", useLargePages: " << opts.useLargePages << std::endl; return true; } - virtual void unmapMemory(void *ptr) const { + virtual void unmapMemory(void *ptr) const override { impl.write("HostMem") << "unmap 0x" << ptr << std::endl; } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp index 694f069852c5..26b7ada78384 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp @@ -78,6 +78,7 @@ XrtAccelerator::XrtAccelerator(Context &ctxt, std::string xclbin, : AcceleratorConnection(ctxt) { impl = make_unique(xclbin, device_id); } +XrtAccelerator::~XrtAccelerator() { disconnect(); } namespace { class XrtMMIO : public MMIO { diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp index a093988cf1bd..450cabece943 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esiquery.cpp @@ -121,10 +121,10 @@ void printInstance(std::ostream &os, const HWModule *d, void printHier(std::ostream &os, AcceleratorConnection &acc) { Manifest manifest(acc.getCtxt(), acc.getService()->getJsonManifest()); - std::unique_ptr design = manifest.buildAccelerator(acc); + Accelerator *design = manifest.buildAccelerator(acc); os << "********************************" << std::endl; os << "* Design hierarchy" << std::endl; os << "********************************" << std::endl; os << std::endl; - printInstance(os, design.get()); + printInstance(os, design); } diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index fa34d4cad4be..d2563616d645 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -50,9 +50,10 @@ int main(int argc, const char *argv[]) { std::unique_ptr acc = ctxt.connect(backend, conn); const auto &info = *acc->getService(); Manifest manifest(ctxt, info.getJsonManifest()); - std::unique_ptr accel = manifest.buildAccelerator(*acc); + Accelerator *accel = manifest.buildAccelerator(*acc); + acc->getServiceThread()->addPoll(*accel); - registerCallbacks(accel.get()); + registerCallbacks(accel); if (cmd == "loop") { while (true) { diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp index c73df6872ca5..08061b77e84a 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp +++ b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp @@ -270,8 +270,14 @@ PYBIND11_MODULE(esiCppAccel, m) { py::class_(m, "Manifest") .def(py::init()) .def_property_readonly("api_version", &Manifest::getApiVersion) - .def("build_accelerator", &Manifest::buildAccelerator, - py::return_value_policy::take_ownership) + .def( + "build_accelerator", + [&](Manifest &m, AcceleratorConnection &conn) { + auto acc = m.buildAccelerator(conn); + conn.getServiceThread()->addPoll(*acc); + return acc; + }, + py::return_value_policy::reference) .def_property_readonly("type_table", &Manifest::getTypeTable) .def_property_readonly("module_infos", &Manifest::getModuleInfos); }