Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make type annotations for NumPy arrays more specific #1358

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import sys
import numpy as np
import numpy.typing as npt

from typing import Generic, Any, Callable, overload
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -289,7 +290,7 @@ def distance_matrix(
parallel_threshold: int = ...,
as_undirected: bool = ...,
null_value: float = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def unweighted_average_shortest_path_length(
graph: PyGraph | PyDiGraph,
parallel_threshold: int = ...,
Expand All @@ -300,7 +301,7 @@ def adjacency_matrix(
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
null_value: float = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def all_simple_paths(
graph: PyGraph | PyDiGraph,
from_: int,
Expand All @@ -319,13 +320,13 @@ def floyd_warshall_numpy(
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
parallel_threshold: int = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def floyd_warshall_successor_and_distance(
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float | None = ...,
parallel_threshold: int | None = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
def astar_shortest_path(
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
node: int,
Expand Down
33 changes: 18 additions & 15 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from rustworkx import generators # noqa
from typing_extensions import Self

import numpy as np
import numpy.typing as npt
import sys

if sys.version_info >= (3, 13):
Expand Down Expand Up @@ -206,15 +207,15 @@ def digraph_adjacency_matrix(
default_weight: float = ...,
null_value: float = ...,
parallel_edge: str = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def graph_adjacency_matrix(
graph: PyGraph[_S, _T],
/,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
null_value: float = ...,
parallel_edge: str = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def cycle_basis(graph: PyGraph, /, root: int | None = ...) -> list[list[int]]: ...
def articulation_points(graph: PyGraph, /) -> set[int]: ...
def bridges(graph: PyGraph, /) -> set[tuple[int]]: ...
Expand Down Expand Up @@ -595,14 +596,14 @@ def undirected_gnp_random_graph(
) -> PyGraph: ...
def directed_sbm_random_graph(
sizes: list[int],
probabilities: np.ndarray,
probabilities: npt.NDArray[np.float64],
loops: bool,
/,
seed: int | None = ...,
) -> PyDiGraph: ...
def undirected_sbm_random_graph(
sizes: list[int],
probabilities: np.ndarray,
probabilities: npt.NDArray[np.float64],
loops: bool,
/,
seed: int | None = ...,
Expand Down Expand Up @@ -863,13 +864,13 @@ def digraph_distance_matrix(
parallel_threshold: int | None = ...,
as_undirected: bool | None = ...,
null_value: float | None = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def graph_distance_matrix(
graph: PyGraph,
/,
parallel_threshold: int | None = ...,
null_value: float | None = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def digraph_floyd_warshall(
graph: PyDiGraph[_S, _T],
/,
Expand All @@ -892,29 +893,29 @@ def digraph_floyd_warshall_numpy(
as_undirected: bool | None = ...,
default_weight: float | None = ...,
parallel_threshold: int | None = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def graph_floyd_warshall_numpy(
graph: PyGraph[_S, _T],
/,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float | None = ...,
parallel_threshold: int | None = ...,
) -> np.ndarray: ...
) -> npt.NDArray[np.float64]: ...
def digraph_floyd_warshall_successor_and_distance(
graph: PyDiGraph[_S, _T],
/,
weight_fn: Callable[[_T], float] | None = ...,
as_undirected: bool | None = ...,
default_weight: float | None = ...,
parallel_threshold: int | None = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
def graph_floyd_warshall_successor_and_distance(
graph: PyGraph[_S, _T],
/,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float | None = ...,
parallel_threshold: int | None = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
def find_negative_cycle(
graph: PyDiGraph[_S, _T],
edge_cost_fn: Callable[[_T], float],
Expand Down Expand Up @@ -1079,7 +1080,9 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
def __len__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: Sequence[_T_co]) -> None: ...
def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ...
def __array__(
self, dtype: np.dtype[Any] | None = ..., copy: bool | None = ...
) -> npt.NDArray[Any]: ...
Comment on lines +1084 to +1085
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to use a generic/typevar here? Since the Any for the dtype argument needs to match the dtype inside the returned array?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the record we ignore the dtype at runtime. However, I cannot annotate the ndarray with _T as that would be incorrect. There is a mapping from _T to numpy types that we’d need to manually add and that is a lot of work.

Also, this method exists to support np.array calls no one should call it directly. I think NumPy type stubs will mask this call anyway

def __iter__(self) -> Iterator[_T_co]: ...
def __reversed__(self) -> Iterator[_T_co]: ...

Expand Down Expand Up @@ -1235,11 +1238,11 @@ class PyGraph(Generic[_S, _T]):
) -> int | None: ...
@staticmethod
def from_adjacency_matrix(
matrix: np.ndarray, /, null_value: float = ...
matrix: npt.NDArray[np.float64], /, null_value: float = ...
) -> PyGraph[int, float]: ...
@staticmethod
def from_complex_adjacency_matrix(
matrix: np.ndarray, /, null_value: complex = ...
matrix: npt.NDArray[np.complex64], /, null_value: complex = ...
) -> PyGraph[int, complex]: ...
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[_T]: ...
def get_edge_data(self, node_a: int, node_b: int, /) -> _T: ...
Expand Down Expand Up @@ -1400,11 +1403,11 @@ class PyDiGraph(Generic[_S, _T]):
) -> list[_S]: ...
@staticmethod
def from_adjacency_matrix(
matrix: np.ndarray, /, null_value: float = ...
matrix: npt.NDArray[np.float64], /, null_value: float = ...
) -> PyDiGraph[int, float]: ...
@staticmethod
def from_complex_adjacency_matrix(
matrix: np.ndarray, /, null_value: complex = ...
matrix: npt.NDArray[np.complex64], /, null_value: complex = ...
) -> PyDiGraph[int, complex]: ...
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[_T]: ...
def get_edge_data(self, node_a: int, node_b: int, /) -> _T: ...
Expand Down
Loading