-
Notifications
You must be signed in to change notification settings - Fork 22
Implement dpnp.interp()
#2417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Implement dpnp.interp()
#2417
Changes from 30 commits
Commits
Show all changes
64 commits
Select commit
Hold shift + click to select a range
a24d367
Initial impl of dpnp.inter()
vlad-perevezentsev e2b20b0
Second impl with dispatch_vector[only floating]
vlad-perevezentsev f7d1da9
Implement interpolate_complex
vlad-perevezentsev e1b8698
Move interpolate backend to ufunc
vlad-perevezentsev 0037455
Move def interp()to dpnp_iface_mathematical
vlad-perevezentsev 7866eb8
Use dispatch vector and remove interpolate_complex_impl
vlad-perevezentsev 51b3bde
Add more backend checks
vlad-perevezentsev ecfa37d
Add support left/right args
vlad-perevezentsev 5d53f9c
Use get_usm_allocations in def interp
vlad-perevezentsev 9dbc2c5
Pass idx as std::int64_t
vlad-perevezentsev 1bafd7c
Add proper casting input array
vlad-perevezentsev 2f43fd7
Update def interp to support period args
vlad-perevezentsev ae65091
Return fp[-1] instead of right_val for x==xp[-1]
vlad-perevezentsev 771d3eb
Unskip cupy tests for interp
vlad-perevezentsev 5cda3d2
Add dpnp tests for interp
vlad-perevezentsev a65a1dd
Update docstrings for def interp()
vlad-perevezentsev 3146234
Merge master into impl_of_interp
vlad-perevezentsev 99cc8b5
Remove lines after merging
vlad-perevezentsev 5ec0738
Merge master into impl_of_interp
vlad-perevezentsev 1263eb5
Add type_check flag to cupy tests
vlad-perevezentsev 7c1fdf1
Merge master into impl_of_interp
vlad-perevezentsev b84dd7e
Add common_interpolate_checks with common utils
vlad-perevezentsev e9e357c
Reuse IsNan from common utils
vlad-perevezentsev 50e4513
Remove dublicate copy
vlad-perevezentsev dbeb313
Add _validate_interp_param() function
vlad-perevezentsev dbb1b55
Impove code coverage
vlad-perevezentsev cbe7e7a
Add sycl_queue tests for interp
vlad-perevezentsev aa102bd
Add usm_type tests for interp()
vlad-perevezentsev 28b2a52
Merge master into impl_of_interp
vlad-perevezentsev 82c657e
Fix pre-commit remark
vlad-perevezentsev b89f41a
Move value_type_of to ext/common.hpp
vlad-perevezentsev 36ee455
Address remarks
vlad-perevezentsev 92d27d8
Merge master into impl_of_interp
vlad-perevezentsev 3b0eb60
Address the rest remarks
vlad-perevezentsev 051dc50
Merge master into impl_of_interp
vlad-perevezentsev 70611c2
helper files
vlad-perevezentsev 9e06cc3
Update value_type_of to support const complex type
vlad-perevezentsev ba987dd
Add check_same_dtype() to validation_utils.hpp
vlad-perevezentsev cbf49d4
Add check_has_dtype() to validation_utils.hpp
vlad-perevezentsev fa5d07a
Use check_num_dims for left/right
vlad-perevezentsev 0a4fdff
Use check_same_dtype for left/right
vlad-perevezentsev f368c17
Add vector vesion of check_num_dims to validation_utils.hpp
vlad-perevezentsev a7d2f50
Add check_same_size to validation_utils.hpp
vlad-perevezentsev 62802f8
Merge master into impl_of_interp
vlad-perevezentsev b97a92e
Remove std::remove_cv_t from value_type_of_impl
vlad-perevezentsev 0f8f453
Remove support for non-scalar period
vlad-perevezentsev 169e002
Delete trash files
vlad-perevezentsev 8bb77fa
Merge master into impl_of_interp
vlad-perevezentsev dd05ffe
Address remarks
vlad-perevezentsev b85234f
Update tests for interp()
vlad-perevezentsev 46ea738
Add support bool fp
vlad-perevezentsev 3c08905
Merge master into impl_of_interp
vlad-perevezentsev 4975148
Update _validate_interp_param()
vlad-perevezentsev 4bf2b03
Use same_kind casting
vlad-perevezentsev 00c9874
Address remarks
vlad-perevezentsev d1b3615
Merge master into impl_of_interp
vlad-perevezentsev 088e8a6
Merge branch 'master' into impl_of_interp
antonwolfy 3d205fd
Update test_errors
vlad-perevezentsev 6ef86be
Improve code coverage
vlad-perevezentsev 699469d
Address remark
vlad-perevezentsev 8049930
A small update _validate_interp_param
vlad-perevezentsev 926c856
Update CHANGELOG
vlad-perevezentsev ae9af0e
Address a minor comment
vlad-perevezentsev df36b25
Merge master into impl_of_interp
vlad-perevezentsev File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
300 changes: 300 additions & 0 deletions
300
dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2025, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#include <complex> | ||
#include <vector> | ||
|
||
#include "dpctl4pybind11.hpp" | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
// dpctl tensor headers | ||
#include "utils/output_validation.hpp" | ||
#include "utils/type_dispatch.hpp" | ||
|
||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include "kernels/elementwise_functions/interpolate.hpp" | ||
|
||
#include "ext/validation_utils.hpp" | ||
|
||
namespace py = pybind11; | ||
namespace td_ns = dpctl::tensor::type_dispatch; | ||
|
||
using ext::validation::array_names; | ||
using ext::validation::array_ptr; | ||
using ext::validation::common_checks; | ||
|
||
namespace dpnp::extensions::ufunc | ||
{ | ||
|
||
namespace impl | ||
{ | ||
|
||
template <typename T> | ||
struct value_type_of | ||
{ | ||
using type = T; | ||
}; | ||
|
||
template <typename T> | ||
struct value_type_of<std::complex<T>> | ||
{ | ||
using type = T; | ||
}; | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
template <typename T> | ||
using value_type_of_t = typename value_type_of<T>::type; | ||
|
||
typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &, | ||
const void *, // x | ||
const void *, // idx | ||
const void *, // xp | ||
const void *, // fp | ||
const void *, // left | ||
const void *, // right | ||
void *, // out | ||
std::size_t, // n | ||
std::size_t, // xp_size | ||
const std::vector<sycl::event> &); | ||
|
||
template <typename T> | ||
sycl::event interpolate_call(sycl::queue &exec_q, | ||
const void *vx, | ||
const void *vidx, | ||
const void *vxp, | ||
const void *vfp, | ||
const void *vleft, | ||
const void *vright, | ||
void *vout, | ||
std::size_t n, | ||
std::size_t xp_size, | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const std::vector<sycl::event> &depends) | ||
{ | ||
using dpctl::tensor::type_utils::is_complex_v; | ||
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>; | ||
|
||
const TCoord *x = static_cast<const TCoord *>(vx); | ||
const std::int64_t *idx = static_cast<const std::int64_t *>(vidx); | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const TCoord *xp = static_cast<const TCoord *>(vxp); | ||
const T *fp = static_cast<const T *>(vfp); | ||
const T *left = static_cast<const T *>(vleft); | ||
const T *right = static_cast<const T *>(vright); | ||
T *out = static_cast<T *>(vout); | ||
|
||
using dpnp::kernels::interpolate::interpolate_impl; | ||
sycl::event interpolate_ev = interpolate_impl<TCoord, T>( | ||
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends); | ||
|
||
return interpolate_ev; | ||
} | ||
|
||
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types]; | ||
|
||
void common_interpolate_checks( | ||
const dpctl::tensor::usm_ndarray &x, | ||
const dpctl::tensor::usm_ndarray &idx, | ||
const dpctl::tensor::usm_ndarray &xp, | ||
const dpctl::tensor::usm_ndarray &fp, | ||
const dpctl::tensor::usm_ndarray &out, | ||
const std::optional<const dpctl::tensor::usm_ndarray> &left, | ||
const std::optional<const dpctl::tensor::usm_ndarray> &right) | ||
{ | ||
array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}}; | ||
|
||
auto array_types = td_ns::usm_ndarray_types(); | ||
int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum()); | ||
int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum()); | ||
int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum()); | ||
int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum()); | ||
|
||
if (x_type_id != xp_type_id) { | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
throw py::value_error("x and xp must have the same dtype"); | ||
} | ||
if (fp_type_id != out_type_id) { | ||
throw py::value_error("fp and out must have the same dtype"); | ||
} | ||
|
||
if (left) { | ||
const auto &l = left.value(); | ||
names.insert({&l, "left"}); | ||
if (l.get_ndim() != 0) { | ||
throw py::value_error("left must be a zero-dimensional array"); | ||
} | ||
|
||
int left_type_id = array_types.typenum_to_lookup_id(l.get_typenum()); | ||
if (left_type_id != fp_type_id) { | ||
throw py::value_error( | ||
"left must have the same dtype as fp and out"); | ||
} | ||
} | ||
|
||
if (right) { | ||
const auto &r = right.value(); | ||
names.insert({&r, "right"}); | ||
if (r.get_ndim() != 0) { | ||
throw py::value_error("right must be a zero-dimensional array"); | ||
} | ||
|
||
int right_type_id = array_types.typenum_to_lookup_id(r.get_typenum()); | ||
if (right_type_id != fp_type_id) { | ||
throw py::value_error( | ||
"right must have the same dtype as fp and out"); | ||
} | ||
} | ||
|
||
common_checks({&x, &xp, &fp, left ? &left.value() : nullptr, | ||
right ? &right.value() : nullptr}, | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{&out}, names); | ||
|
||
if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 || | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
idx.get_ndim() != 1 || out.get_ndim() != 1) | ||
{ | ||
throw py::value_error("All arrays must be one-dimensional"); | ||
} | ||
|
||
if (xp.get_size() != fp.get_size()) { | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
throw py::value_error("xp and fp must have the same size"); | ||
} | ||
|
||
if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) { | ||
throw py::value_error("x, idx, and out must have the same size"); | ||
} | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
py_interpolate(const dpctl::tensor::usm_ndarray &x, | ||
const dpctl::tensor::usm_ndarray &idx, | ||
const dpctl::tensor::usm_ndarray &xp, | ||
const dpctl::tensor::usm_ndarray &fp, | ||
std::optional<const dpctl::tensor::usm_ndarray> &left, | ||
std::optional<const dpctl::tensor::usm_ndarray> &right, | ||
dpctl::tensor::usm_ndarray &out, | ||
sycl::queue &exec_q, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
if (x.get_size() == 0) { | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return {sycl::event(), sycl::event()}; | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
common_interpolate_checks(x, idx, xp, fp, out, left, right); | ||
|
||
int out_typenum = out.get_typenum(); | ||
|
||
auto array_types = td_ns::usm_ndarray_types(); | ||
int out_type_id = array_types.typenum_to_lookup_id(out_typenum); | ||
|
||
auto fn = interpolate_dispatch_vector[out_type_id]; | ||
if (!fn) { | ||
throw py::type_error("Unsupported dtype"); | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
std::size_t n = x.get_size(); | ||
std::size_t xp_size = xp.get_size(); | ||
|
||
void *left_ptr = left ? left.value().get_data() : nullptr; | ||
void *right_ptr = right ? right.value().get_data() : nullptr; | ||
|
||
sycl::event ev = | ||
fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(), | ||
left_ptr, right_ptr, out.get_data(), n, xp_size, depends); | ||
|
||
sycl::event args_ev; | ||
|
||
if (left && right) { | ||
args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev}); | ||
} | ||
else if (left) { | ||
args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {x, idx, xp, fp, out, left.value()}, {ev}); | ||
} | ||
else if (right) { | ||
args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {x, idx, xp, fp, out, right.value()}, {ev}); | ||
} | ||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else { | ||
args_ev = | ||
dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev}); | ||
} | ||
|
||
return std::make_pair(args_ev, ev); | ||
} | ||
|
||
/** | ||
* @brief A factory to define pairs of supported types for which | ||
* interpolate function is available. | ||
* | ||
* @tparam T Type of input vector `a` and of result vector `y`. | ||
*/ | ||
template <typename T> | ||
struct InterpolateOutputType | ||
{ | ||
using value_type = typename std::disjunction< | ||
td_ns::TypeMapResultEntry<T, sycl::half>, | ||
antonwolfy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
td_ns::TypeMapResultEntry<T, float>, | ||
td_ns::TypeMapResultEntry<T, double>, | ||
td_ns::TypeMapResultEntry<T, std::complex<float>>, | ||
td_ns::TypeMapResultEntry<T, std::complex<double>>, | ||
td_ns::DefaultResultEntry<void>>::result_type; | ||
}; | ||
|
||
template <typename fnT, typename T> | ||
struct InterpolateFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (std::is_same_v< | ||
typename InterpolateOutputType<T>::value_type, void>) | ||
{ | ||
return nullptr; | ||
} | ||
else { | ||
return interpolate_call<T>; | ||
} | ||
} | ||
}; | ||
|
||
void init_interpolate_dispatch_vectors() | ||
{ | ||
using namespace td_ns; | ||
|
||
DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types> | ||
dtb_interpolate; | ||
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector); | ||
} | ||
|
||
} // namespace impl | ||
|
||
void init_interpolate(py::module_ m) | ||
{ | ||
impl::init_interpolate_dispatch_vectors(); | ||
|
||
using impl::py_interpolate; | ||
m.def("_interpolate", &py_interpolate, "", py::arg("x"), py::arg("idx"), | ||
py::arg("xp"), py::arg("fp"), py::arg("left"), py::arg("right"), | ||
py::arg("out"), py::arg("sycl_queue"), | ||
py::arg("depends") = py::list()); | ||
} | ||
|
||
} // namespace dpnp::extensions::ufunc |
35 changes: 35 additions & 0 deletions
35
dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2025, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
namespace py = pybind11; | ||
|
||
namespace dpnp::extensions::ufunc | ||
{ | ||
void init_interpolate(py::module_ m); | ||
} // namespace dpnp::extensions::ufunc |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.