@@ -40,7 +40,7 @@ KernelCode* py_createKernelCode(const std::string &pData, size_t workgroupSize,
40
40
return new KernelCode (pData, workgroupSize, (NumType)precision);
41
41
}
42
42
43
- Kernel* py_createKernel (Context *ctx, const KernelCode *code,
43
+ Kernel py_createKernel (Context *ctx, const KernelCode *code,
44
44
// const Tensor *dataBindings, size_t numTensors,
45
45
const py::list& dataBindings_py,
46
46
// const size_t *viewOffsets,
@@ -54,7 +54,7 @@ Kernel* py_createKernel(Context *ctx, const KernelCode *code,
54
54
for (auto item : viewOffsets_py) {
55
55
viewOffsets.push_back (item.cast <size_t >());
56
56
}
57
- return new Kernel ( createKernel (*ctx, *code, bindings.data (), bindings.size (), viewOffsets.data (), vector_to_shape (totalWorkgroups) ));
57
+ return createKernel (*ctx, *code, bindings.data (), bindings.size (), viewOffsets.data (), vector_to_shape (totalWorkgroups));
58
58
}
59
59
60
60
Tensor* py_createTensor (Context *ctx, const std::vector<int > &dims, int dtype) {
@@ -82,9 +82,9 @@ struct GpuAsync {
82
82
}
83
83
};
84
84
85
- GpuAsync* py_dispatchKernel (Context *ctx, Kernel * kernel) {
85
+ GpuAsync* py_dispatchKernel (Context *ctx, Kernel kernel) {
86
86
auto async = new GpuAsync ();
87
- dispatchKernel (*ctx, * kernel, async->promise );
87
+ dispatchKernel (*ctx, kernel, async->promise );
88
88
return async;
89
89
}
90
90
@@ -96,12 +96,12 @@ PYBIND11_MODULE(gpu_cpp, m) {
96
96
m.doc () = " gpu.cpp plugin" ;
97
97
py::class_<Context>(m, " Context" );
98
98
py::class_<Tensor>(m, " Tensor" );
99
- py::class_<Kernel >(m, " Kernel" );
99
+ py::class_<RawKernel, std::shared_ptr<RawKernel> >(m, " Kernel" );
100
100
py::class_<KernelCode>(m, " KernelCode" );
101
101
py::class_<GpuAsync>(m, " GpuAsync" );
102
102
m.def (" create_context" , &py_createContext, py::return_value_policy::take_ownership);
103
103
m.def (" create_tensor" , &py_createTensor, py::return_value_policy::take_ownership);
104
- m.def (" create_kernel" , &py_createKernel, py::return_value_policy::take_ownership );
104
+ m.def (" create_kernel" , &py_createKernel);
105
105
m.def (" create_kernel_code" , &py_createKernelCode, py::return_value_policy::take_ownership);
106
106
m.def (" dispatch_kernel" , &py_dispatchKernel, py::return_value_policy::take_ownership);
107
107
m.def (" wait" , &py_wait, " Wait for GPU" );
0 commit comments