-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
String dtype: implement object-dtype based StringArray variant with NumPy semantics #58451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
63a7fc5
0eee625
607b95e
bca157d
79eb3b4
c063298
ab96aa4
bae8d65
31f1c33
cbd0820
864c166
d3ad7b0
028dc2c
1750bcb
7f4baf7
fdf1454
fe6fce6
70325d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -790,6 +790,16 @@ def assert_extension_array_equal( | |
left_na, right_na, obj=f"{obj} NA mask", index_values=index_values | ||
) | ||
|
||
# Specifically for StringArrayNumpySemantics, validate here we have a valid array | ||
if isinstance(left.dtype, StringDtype) and left.dtype.storage == "python_numpy": | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert np.all( | ||
[np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined] | ||
), "wrong missing value sentinels" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bit of a custom check (and we don't do anything similar for other types), but given that I initially overlooked a case where we were creating string arrays with the wrong missing value sentinel because the tests don't actually catch that (two arrays with different missing value sentinels still pass as equal in the case of EAs), I would prefer keeping this in, at least in the short term. |
||
if isinstance(right.dtype, StringDtype) and right.dtype.storage == "python_numpy": | ||
assert np.all( | ||
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined] | ||
), "wrong missing value sentinels" | ||
|
||
left_valid = left[~left_na].to_numpy(dtype=object) | ||
right_valid = right[~right_na].to_numpy(dtype=object) | ||
if check_exact: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,31 @@ | ||
from __future__ import annotations | ||
|
||
import operator | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
ClassVar, | ||
Literal, | ||
cast, | ||
) | ||
|
||
import numpy as np | ||
|
||
from pandas._config import get_option | ||
from pandas._config import ( | ||
get_option, | ||
using_pyarrow_string_dtype, | ||
) | ||
|
||
from pandas._libs import ( | ||
lib, | ||
missing as libmissing, | ||
) | ||
from pandas._libs.arrays import NDArrayBacked | ||
from pandas._libs.lib import ensure_string_array | ||
from pandas.compat import pa_version_under10p1 | ||
from pandas.compat import ( | ||
HAS_PYARROW, | ||
pa_version_under10p1, | ||
) | ||
from pandas.compat.numpy import function as nv | ||
from pandas.util._decorators import doc | ||
|
||
|
@@ -81,7 +89,7 @@ class StringDtype(StorageExtensionDtype): | |
|
||
Parameters | ||
---------- | ||
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional | ||
storage : {"python", "pyarrow", "python_numpy", "pyarrow_numpy"}, optional | ||
If not given, the value of ``pd.options.mode.string_storage``. | ||
|
||
Attributes | ||
|
@@ -113,7 +121,7 @@ class StringDtype(StorageExtensionDtype): | |
# follows NumPy semantics, which uses nan. | ||
@property | ||
def na_value(self) -> libmissing.NAType | float: # type: ignore[override] | ||
if self.storage == "pyarrow_numpy": | ||
if self.storage in ("pyarrow_numpy", "python_numpy"): | ||
return np.nan | ||
else: | ||
return libmissing.NA | ||
|
@@ -122,15 +130,17 @@ def na_value(self) -> libmissing.NAType | float: # type: ignore[override] | |
|
||
def __init__(self, storage=None) -> None: | ||
if storage is None: | ||
infer_string = get_option("future.infer_string") | ||
if infer_string: | ||
storage = "pyarrow_numpy" | ||
if using_pyarrow_string_dtype(): | ||
if HAS_PYARROW: | ||
storage = "pyarrow_numpy" | ||
else: | ||
storage = "python_numpy" | ||
else: | ||
storage = get_option("mode.string_storage") | ||
if storage not in {"python", "pyarrow", "pyarrow_numpy"}: | ||
if storage not in {"python", "pyarrow", "python_numpy", "pyarrow_numpy"}: | ||
raise ValueError( | ||
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. " | ||
f"Got {storage} instead." | ||
"Storage must be 'python', 'pyarrow', 'python_numpy' or 'pyarrow_numpy'" | ||
f". Got {storage} instead." | ||
) | ||
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1: | ||
raise ImportError( | ||
|
@@ -178,6 +188,8 @@ def construct_from_string(cls, string) -> Self: | |
return cls() | ||
elif string == "string[python]": | ||
return cls(storage="python") | ||
elif string == "string[python_numpy]": | ||
return cls(storage="python_numpy") | ||
elif string == "string[pyarrow]": | ||
return cls(storage="pyarrow") | ||
elif string == "string[pyarrow_numpy]": | ||
|
@@ -207,6 +219,8 @@ def construct_array_type( # type: ignore[override] | |
return StringArray | ||
elif self.storage == "pyarrow": | ||
return ArrowStringArray | ||
elif self.storage == "python_numpy": | ||
return StringArrayNumpySemantics | ||
else: | ||
return ArrowStringArrayNumpySemantics | ||
|
||
|
@@ -238,7 +252,7 @@ def __from_arrow__( | |
# convert chunk by chunk to numpy and concatenate then, to avoid | ||
# overflow for large string data when concatenating the pyarrow arrays | ||
arr = arr.to_numpy(zero_copy_only=False) | ||
arr = ensure_string_array(arr, na_value=libmissing.NA) | ||
arr = ensure_string_array(arr, na_value=self.na_value) | ||
results.append(arr) | ||
|
||
if len(chunks) == 0: | ||
|
@@ -248,11 +262,7 @@ def __from_arrow__( | |
|
||
# Bypass validation inside StringArray constructor, see GH#47781 | ||
new_string_array = StringArray.__new__(StringArray) | ||
NDArrayBacked.__init__( | ||
new_string_array, | ||
arr, | ||
StringDtype(storage="python"), | ||
) | ||
NDArrayBacked.__init__(new_string_array, arr, self) | ||
return new_string_array | ||
|
||
|
||
|
@@ -360,14 +370,15 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc] | |
|
||
# undo the NumpyExtensionArray hack | ||
_typ = "extension" | ||
_storage = "python" | ||
|
||
def __init__(self, values, copy: bool = False) -> None: | ||
values = extract_array(values) | ||
|
||
super().__init__(values, copy=copy) | ||
if not isinstance(values, type(self)): | ||
self._validate() | ||
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python")) | ||
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage=self._storage)) | ||
|
||
def _validate(self) -> None: | ||
"""Validate that we only store NA or strings.""" | ||
|
@@ -385,22 +396,41 @@ def _validate(self) -> None: | |
else: | ||
lib.convert_nans_to_NA(self._ndarray) | ||
|
||
def _validate_scalar(self, value): | ||
# used by NDArrayBackedExtensionIndex.insert | ||
if isna(value): | ||
return self.dtype.na_value | ||
elif not isinstance(value, str): | ||
raise TypeError( | ||
f"Cannot set non-string value '{value}' into a string array." | ||
) | ||
return value | ||
|
||
@classmethod | ||
def _from_sequence( | ||
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False | ||
) -> Self: | ||
if dtype and not (isinstance(dtype, str) and dtype == "string"): | ||
dtype = pandas_dtype(dtype) | ||
assert isinstance(dtype, StringDtype) and dtype.storage == "python" | ||
assert isinstance(dtype, StringDtype) and dtype.storage in ( | ||
"python", | ||
"python_numpy", | ||
) | ||
else: | ||
if get_option("future.infer_string"): | ||
dtype = StringDtype(storage="python_numpy") | ||
else: | ||
dtype = StringDtype(storage="python") | ||
|
||
from pandas.core.arrays.masked import BaseMaskedArray | ||
|
||
na_value = dtype.na_value | ||
if isinstance(scalars, BaseMaskedArray): | ||
# avoid costly conversion to object dtype | ||
na_values = scalars._mask | ||
result = scalars._data | ||
result = lib.ensure_string_array(result, copy=copy, convert_na_value=False) | ||
result[na_values] = libmissing.NA | ||
result[na_values] = na_value | ||
|
||
else: | ||
if lib.is_pyarrow_array(scalars): | ||
|
@@ -409,12 +439,12 @@ def _from_sequence( | |
# zero_copy_only to True which caused problems see GH#52076 | ||
scalars = np.array(scalars) | ||
# convert non-na-likes to str, and nan-likes to StringDtype().na_value | ||
result = lib.ensure_string_array(scalars, na_value=libmissing.NA, copy=copy) | ||
result = lib.ensure_string_array(scalars, na_value=na_value, copy=copy) | ||
|
||
# Manually creating new array avoids the validation step in the __init__, so is | ||
# faster. Refactor need for validation? | ||
new_string_array = cls.__new__(cls) | ||
NDArrayBacked.__init__(new_string_array, result, StringDtype(storage="python")) | ||
NDArrayBacked.__init__(new_string_array, result, dtype) | ||
|
||
return new_string_array | ||
|
||
|
@@ -464,7 +494,7 @@ def __setitem__(self, key, value) -> None: | |
# validate new items | ||
if scalar_value: | ||
if isna(value): | ||
value = libmissing.NA | ||
value = self.dtype.na_value | ||
elif not isinstance(value, str): | ||
raise TypeError( | ||
f"Cannot set non-string value '{value}' into a StringArray." | ||
|
@@ -478,7 +508,7 @@ def __setitem__(self, key, value) -> None: | |
mask = isna(value) | ||
if mask.any(): | ||
value = value.copy() | ||
value[isna(value)] = libmissing.NA | ||
value[isna(value)] = self.dtype.na_value | ||
|
||
super().__setitem__(key, value) | ||
|
||
|
@@ -600,9 +630,9 @@ def _cmp_method(self, other, op): | |
|
||
if op.__name__ in ops.ARITHMETIC_BINOPS: | ||
result = np.empty_like(self._ndarray, dtype="object") | ||
result[mask] = libmissing.NA | ||
result[mask] = self.dtype.na_value | ||
result[valid] = op(self._ndarray[valid], other) | ||
return StringArray(result) | ||
return self._from_backing_data(result) | ||
else: | ||
# logical | ||
result = np.zeros(len(self._ndarray), dtype="bool") | ||
|
@@ -671,3 +701,106 @@ def _str_map( | |
# or .findall returns a list). | ||
# -> We don't know the result type. E.g. `.get` can return anything. | ||
return lib.map_infer_mask(arr, f, mask.view("uint8")) | ||
|
||
|
||
class StringArrayNumpySemantics(StringArray): | ||
_storage = "python_numpy" | ||
|
||
def _validate(self) -> None: | ||
"""Validate that we only store NaN or strings.""" | ||
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): | ||
raise ValueError( | ||
"StringArrayNumpySemantics requires a sequence of strings or NaN" | ||
) | ||
if self._ndarray.dtype != "object": | ||
raise ValueError( | ||
"StringArrayNumpySemantics requires a sequence of strings or NaN. Got " | ||
f"'{self._ndarray.dtype}' dtype instead." | ||
) | ||
# TODO validate or force NA/None to NaN | ||
|
||
@classmethod | ||
def _from_sequence( | ||
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False | ||
) -> Self: | ||
if dtype is None: | ||
dtype = StringDtype(storage="python_numpy") | ||
return super()._from_sequence(scalars, dtype=dtype, copy=copy) | ||
|
||
def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics: | ||
# need to overrde NumpyExtensionArray._from_backing_data to ensure | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# we always preserve the dtype | ||
return NDArrayBacked._from_backing_data(self, arr) | ||
|
||
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any: | ||
# the masked_reductions use pd.NA | ||
if result is libmissing.NA: | ||
return np.nan | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might want to return `self._na_value` here to make things explicit. |
||
return super()._wrap_reduction_result(axis, result) | ||
|
||
def _cmp_method(self, other, op): | ||
result = super()._cmp_method(other, op) | ||
if op == operator.ne: | ||
return result.to_numpy(np.bool_, na_value=True) | ||
else: | ||
return result.to_numpy(np.bool_, na_value=False) | ||
|
||
def value_counts(self, dropna: bool = True) -> Series: | ||
from pandas.core.algorithms import value_counts_internal as value_counts | ||
|
||
result = value_counts(self._ndarray, sort=False, dropna=dropna) | ||
result.index = result.index.astype(self.dtype) | ||
return result | ||
|
||
# ------------------------------------------------------------------------ | ||
# String methods interface | ||
_str_na_value = np.nan | ||
|
||
def _str_map( | ||
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True | ||
): | ||
if dtype is None: | ||
dtype = self.dtype | ||
if na_value is None: | ||
na_value = self.dtype.na_value | ||
|
||
mask = isna(self) | ||
arr = np.asarray(self) | ||
convert = convert and not np.all(mask) | ||
|
||
if is_integer_dtype(dtype) or is_bool_dtype(dtype): | ||
na_value_is_na = isna(na_value) | ||
if na_value_is_na: | ||
if is_integer_dtype(dtype): | ||
na_value = 0 | ||
else: | ||
na_value = True | ||
|
||
result = lib.map_infer_mask( | ||
arr, | ||
f, | ||
mask.view("uint8"), | ||
convert=False, | ||
na_value=na_value, | ||
dtype=np.dtype(cast(type, dtype)), | ||
) | ||
if na_value_is_na and mask.any(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This method (which has now been refactored to `_str_map_nan_semantics`) is slightly different in StringArray vs ArrowStringArray, and I'm trying to sort out whether the differences are intentional or just cosmetic. Could use some help from the author.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Whoops, my claim in 3 about it not mattering was incorrect. It matters for `test_contains_nan` and `test_empty_str_methods`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Although an author who wrote this code almost 4 months ago ;) I will take a closer look at it later today, but one quick find is that there were changes to the Arrow version after I started this PR, so I might not have taken those into account in this version, e.g. #58483. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've convinced myself that the Arrow version doesn't need the `na_value_is_na` check because it is always True. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ... and that `convert` is never used. |
||
if is_integer_dtype(dtype): | ||
result = result.astype("float64") | ||
else: | ||
result = result.astype("object") | ||
result[mask] = np.nan | ||
return result | ||
|
||
elif is_string_dtype(dtype) and not is_object_dtype(dtype): | ||
# i.e. StringDtype | ||
result = lib.map_infer_mask( | ||
arr, f, mask.view("uint8"), convert=False, na_value=na_value | ||
) | ||
return type(self)(result) | ||
else: | ||
# This is when the result type is object. We reach this when | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this raise an error or not be possible in the first place? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Some str methods are weird (i.e. what's in the comment here). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And not only weird, there are some methods that genuinely return an object dtype (of course because of the lack of a better proper dtype, but right now, with the default dtype, this is object dtype). For example There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense. The list-returning functions are more good use cases for PDEP-13 #58455 |
||
# -> We know the result type is truly object (e.g. .encode returns bytes | ||
# or .findall returns a list). | ||
# -> We don't know the result type. E.g. `.get` can return anything. | ||
return lib.map_infer_mask(arr, f, mask.view("uint8")) |
Uh oh!
There was an error while loading. Please reload this page.