@@ -40,7 +40,7 @@ KernelCode* py_createKernelCode(const std::string &pData, size_t workgroupSize,
4040 return new KernelCode (pData, workgroupSize, (NumType)precision);
4141}
4242
43- Kernel* py_createKernel (Context *ctx, const KernelCode *code,
43+ Kernel py_createKernel (Context *ctx, const KernelCode *code,
4444 // const Tensor *dataBindings, size_t numTensors,
4545 const py::list& dataBindings_py,
4646 // const size_t *viewOffsets,
@@ -54,7 +54,7 @@ Kernel* py_createKernel(Context *ctx, const KernelCode *code,
5454 for (auto item : viewOffsets_py) {
5555 viewOffsets.push_back (item.cast <size_t >());
5656 }
57- return new Kernel ( createKernel (*ctx, *code, bindings.data (), bindings.size (), viewOffsets.data (), vector_to_shape (totalWorkgroups) ));
57+ return createKernel (*ctx, *code, bindings.data (), bindings.size (), viewOffsets.data (), vector_to_shape (totalWorkgroups));
5858}
5959
6060Tensor* py_createTensor (Context *ctx, const std::vector<int > &dims, int dtype) {
@@ -82,9 +82,9 @@ struct GpuAsync {
8282 }
8383};
8484
85- GpuAsync* py_dispatchKernel (Context *ctx, Kernel * kernel) {
85+ GpuAsync* py_dispatchKernel (Context *ctx, Kernel kernel) {
8686 auto async = new GpuAsync ();
87- dispatchKernel (*ctx, * kernel, async->promise );
87+ dispatchKernel (*ctx, kernel, async->promise );
8888 return async;
8989}
9090
@@ -96,12 +96,12 @@ PYBIND11_MODULE(gpu_cpp, m) {
9696 m.doc () = " gpu.cpp plugin" ;
9797 py::class_<Context>(m, " Context" );
9898 py::class_<Tensor>(m, " Tensor" );
99- py::class_<Kernel >(m, " Kernel" );
99+ py::class_<RawKernel, std::shared_ptr<RawKernel> >(m, " Kernel" );
100100 py::class_<KernelCode>(m, " KernelCode" );
101101 py::class_<GpuAsync>(m, " GpuAsync" );
102102 m.def (" create_context" , &py_createContext, py::return_value_policy::take_ownership);
103103 m.def (" create_tensor" , &py_createTensor, py::return_value_policy::take_ownership);
104- m.def (" create_kernel" , &py_createKernel, py::return_value_policy::take_ownership );
104+ m.def (" create_kernel" , &py_createKernel);
105105 m.def (" create_kernel_code" , &py_createKernelCode, py::return_value_policy::take_ownership);
106106 m.def (" dispatch_kernel" , &py_dispatchKernel, py::return_value_policy::take_ownership);
107107 m.def (" wait" , &py_wait, " Wait for GPU" );
0 commit comments