Skip to content

Implement dpnp.interp() #2417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 44 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
a24d367
Initial impl of dpnp.inter()
vlad-perevezentsev Mar 25, 2025
e2b20b0
Second impl with dispatch_vector[only floating]
vlad-perevezentsev Apr 2, 2025
f7d1da9
Implement interpolate_complex
vlad-perevezentsev Apr 2, 2025
e1b8698
Move interpolate backend to ufunc
vlad-perevezentsev Apr 2, 2025
0037455
Move def interp()to dpnp_iface_mathematical
vlad-perevezentsev Apr 2, 2025
7866eb8
Use dispatch vector and remove interpolate_complex_impl
vlad-perevezentsev Apr 2, 2025
51b3bde
Add more backend checks
vlad-perevezentsev Apr 2, 2025
ecfa37d
Add support left/right args
vlad-perevezentsev Apr 10, 2025
5d53f9c
Use get_usm_allocations in def interp
vlad-perevezentsev Apr 10, 2025
9dbc2c5
Pass idx as std::int64_t
vlad-perevezentsev Apr 11, 2025
1bafd7c
Add proper casting input array
vlad-perevezentsev Apr 11, 2025
2f43fd7
Update def interp to support period args
vlad-perevezentsev Apr 11, 2025
ae65091
Return fp[-1] instead of right_val for x==xp[-1]
vlad-perevezentsev Apr 11, 2025
771d3eb
Unskip cupy tests for interp
vlad-perevezentsev Apr 11, 2025
5cda3d2
Add dpnp tests for interp
vlad-perevezentsev Apr 11, 2025
a65a1dd
Update docstrings for def interp()
vlad-perevezentsev Apr 11, 2025
3146234
Merge master into impl_of_interp
vlad-perevezentsev Apr 11, 2025
99cc8b5
Remove lines after merging
vlad-perevezentsev Apr 11, 2025
5ec0738
Merge master into impl_of_interp
vlad-perevezentsev Apr 11, 2025
1263eb5
Add type_check flag to cupy tests
vlad-perevezentsev Apr 14, 2025
7c1fdf1
Merge master into impl_of_interp
vlad-perevezentsev Apr 14, 2025
b84dd7e
Add common_interpolate_checks with common utils
vlad-perevezentsev Apr 14, 2025
e9e357c
Reuse IsNan from common utils
vlad-perevezentsev Apr 14, 2025
50e4513
Remove dublicate copy
vlad-perevezentsev Apr 14, 2025
dbeb313
Add _validate_interp_param() function
vlad-perevezentsev Apr 14, 2025
dbb1b55
Impove code coverage
vlad-perevezentsev Apr 15, 2025
cbe7e7a
Add sycl_queue tests for interp
vlad-perevezentsev Apr 15, 2025
aa102bd
Add usm_type tests for interp()
vlad-perevezentsev Apr 15, 2025
28b2a52
Merge master into impl_of_interp
vlad-perevezentsev Apr 15, 2025
82c657e
Fix pre-commit remark
vlad-perevezentsev Apr 15, 2025
b89f41a
Move value_type_of to ext/common.hpp
vlad-perevezentsev Apr 28, 2025
36ee455
Address remarks
vlad-perevezentsev Apr 28, 2025
92d27d8
Merge master into impl_of_interp
vlad-perevezentsev Apr 28, 2025
3b0eb60
Address the rest remarks
vlad-perevezentsev Apr 28, 2025
051dc50
Merge master into impl_of_interp
vlad-perevezentsev Apr 28, 2025
70611c2
helper files
vlad-perevezentsev Apr 29, 2025
9e06cc3
Update value_type_of to support const complex type
vlad-perevezentsev Apr 30, 2025
ba987dd
Add check_same_dtype() to validation_utils.hpp
vlad-perevezentsev Apr 30, 2025
cbf49d4
Add check_has_dtype() to validation_utils.hpp
vlad-perevezentsev Apr 30, 2025
fa5d07a
Use check_num_dims for left/right
vlad-perevezentsev Apr 30, 2025
0a4fdff
Use check_same_dtype for left/right
vlad-perevezentsev Apr 30, 2025
f368c17
Add vector vesion of check_num_dims to validation_utils.hpp
vlad-perevezentsev Apr 30, 2025
a7d2f50
Add check_same_size to validation_utils.hpp
vlad-perevezentsev Apr 30, 2025
62802f8
Merge master into impl_of_interp
vlad-perevezentsev Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dpnp/backend/extensions/common/ext/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ struct IsNan
}
};

template <typename T, bool hasValueType>
struct value_type_of_impl;

template <typename T>
struct value_type_of_impl<T, false>
{
using type = T;
};

template <typename T>
struct value_type_of_impl<T, true>
{
using type = typename std::remove_cv_t<T>::value_type;
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation has a problem of not supporting const std::complex<T> case.

Better approach could be something like:

template<typename T, bool hasValueType>
struct value_type_of_impl;

template<typename T>
struct value_type_of_impl<T, false>
{
    using type = T;
};

template<typename T>
struct value_type_of_impl<T, true>
{
    using type = typename T::value_type;
};

template<T>
using value_type_of = value_type_of_impl<T, is_complex_v<T>>;

template<T>
using value_type_of_t = typename value_type_of<T>::type;

And in future is_complex could be replaced with proper has_value_type if needed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is good point but I would prefer using type = typename std::remove_cv_t<T>::value_type to ensure it works correctly with const types

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean something like?

struct value_type_of<std::complex<T>>
{
    using type = typename std::remove_cv_t<T>::value_type;
};

Then, I believe, it will not work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually meant removing const in value_type_of_impl<T, true>

template<typename T>
struct value_type_of_impl<T, true>
{
    using type = typename std::remove_cv_t<T>::value_type;
};

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 9e06cc3


template <typename T>
using value_type_of = value_type_of_impl<T, type_utils::is_complex_v<T>>;

template <typename T>
using value_type_of_t = typename value_type_of<T>::type;

size_t get_max_local_size(const sycl::device &device);
size_t get_max_local_size(const sycl::device &device,
int cpu_local_size_limit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,17 @@
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include "ext/common.hpp"

#include "ext/validation_utils.hpp"
#include "utils/memory_overlap.hpp"

namespace td_ns = dpctl::tensor::type_dispatch;
namespace common = ext::common;

namespace ext::validation
{
inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
Expand Down Expand Up @@ -137,6 +145,15 @@ inline void check_num_dims(const array_ptr &arr,
}
}

inline void check_num_dims(const std::vector<array_ptr> &arrays,
const size_t ndim,
const array_names &names)
{
for (const auto &arr : arrays) {
check_num_dims(arr, ndim, names);
}
}

inline void check_max_dims(const array_ptr &arr,
const size_t max_ndim,
const array_names &names)
Expand All @@ -163,6 +180,103 @@ inline void check_size_at_least(const array_ptr &arr,
}
}

inline void check_has_dtype(const array_ptr &arr,
const typenum_t dtype,
const array_names &names)
{
if (arr == nullptr) {
return;
}

auto array_types = td_ns::usm_ndarray_types();
int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
int expected_type_id = static_cast<int>(dtype);

if (array_type_id != expected_type_id) {
py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);

std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
std::string(py::str(dtype_py)) + ", but got " +
std::string(py::str(actual_dtype));

throw py::value_error(msg);
}
}

inline void check_same_dtype(const array_ptr &arr1,
const array_ptr &arr2,
const array_names &names)
{
if (arr1 == nullptr || arr2 == nullptr) {
return;
}

auto array_types = td_ns::usm_ndarray_types();
int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());

if (first_type_id != second_type_id) {
py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
py::dtype second_dtype = common::dtype_from_typenum(second_type_id);

std::string msg = "Arrays " + name_of(arr1, names) + " and " +
name_of(arr2, names) +
" must have the same dtype, but got " +
std::string(py::str(first_dtype)) + " and " +
std::string(py::str(second_dtype));

throw py::value_error(msg);
}
}

inline void check_same_dtype(const std::vector<array_ptr> &arrays,
const array_names &names)
{
if (arrays.empty()) {
return;
}

const auto *first = arrays[0];
for (size_t i = 1; i < arrays.size(); ++i) {
check_same_dtype(first, arrays[i], names);
}
}

inline void check_same_size(const array_ptr &arr1,
const array_ptr &arr2,
const array_names &names)
{
if (arr1 == nullptr || arr2 == nullptr) {
return;
}

auto size1 = arr1->get_size();
auto size2 = arr2->get_size();

if (size1 != size2) {
std::string msg =
"Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
" must have the same size, but got " + std::to_string(size1) +
" and " + std::to_string(size2);

throw py::value_error(msg);
}
}

inline void check_same_size(const std::vector<array_ptr> &arrays,
const array_names &names)
{
if (arrays.empty()) {
return;
}

auto first = arrays[0];
for (size_t i = 1; i < arrays.size(); ++i) {
check_same_size(first, arrays[i], names);
}
}

inline void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names)
Expand Down
18 changes: 18 additions & 0 deletions dpnp/backend/extensions/common/ext/validation_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace ext::validation
{
using array_ptr = const dpctl::tensor::usm_ndarray *;
using array_names = std::unordered_map<array_ptr, std::string>;
using dpctl::tensor::type_dispatch::typenum_t;

std::string name_of(const array_ptr &arr, const array_names &names);

Expand All @@ -56,6 +57,9 @@ void check_no_overlap(const std::vector<array_ptr> &inputs,
void check_num_dims(const array_ptr &arr,
const size_t ndim,
const array_names &names);
void check_num_dims(const std::vector<array_ptr> &arrays,
const size_t ndim,
const array_names &names);
void check_max_dims(const array_ptr &arr,
const size_t max_ndim,
const array_names &names);
Expand All @@ -64,6 +68,20 @@ void check_size_at_least(const array_ptr &arr,
const size_t size,
const array_names &names);

void check_has_dtype(const array_ptr &arr,
const typenum_t dtype,
const array_names &names);

void check_same_dtype(const array_ptr &arr1,
const array_ptr &arr2,
const array_names &names);

void check_same_size(const array_ptr &arr1,
const array_ptr &arr2,
const array_names &names);
void check_same_size(const std::vector<array_ptr> &arrays,
const array_names &names);

void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names);
Expand Down
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/ufunc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
Expand Down Expand Up @@ -69,6 +70,7 @@ endif()
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)

target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)

target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "gcd.hpp"
#include "heaviside.hpp"
#include "i0.hpp"
#include "interpolate.hpp"
#include "lcm.hpp"
#include "ldexp.hpp"
#include "logaddexp2.hpp"
Expand Down Expand Up @@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
init_gcd(m);
init_heaviside(m);
init_i0(m);
init_interpolate(m);
init_lcm(m);
init_ldexp(m);
init_logaddexp2(m);
Expand Down
Loading
Loading