Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COMPAT: prepare for pandas 3.0 string dtype #493

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyogrio/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PANDAS_GE_15 = pandas is not None and Version(pandas.__version__) >= Version("1.5.0")
PANDAS_GE_20 = pandas is not None and Version(pandas.__version__) >= Version("2.0.0")
PANDAS_GE_22 = pandas is not None and Version(pandas.__version__) >= Version("2.2.0")
PANDAS_GE_30 = pandas is not None and Version(pandas.__version__) >= Version("3.0.0dev")

GDAL_GE_352 = __gdal_version__ >= (3, 5, 2)
GDAL_GE_38 = __gdal_version__ >= (3, 8, 0)
Expand Down
19 changes: 16 additions & 3 deletions pyogrio/geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

import numpy as np

from pyogrio._compat import HAS_GEOPANDAS, PANDAS_GE_15, PANDAS_GE_20, PANDAS_GE_22
from pyogrio._compat import (
HAS_GEOPANDAS,
PANDAS_GE_15,
PANDAS_GE_20,
PANDAS_GE_22,
PANDAS_GE_30,
)
from pyogrio.errors import DataSourceError
from pyogrio.raw import (
DRIVERS_NO_MIXED_DIMENSIONS,
Expand Down Expand Up @@ -52,13 +58,13 @@ def _try_parse_datetime(ser):
except Exception:
res = ser
# if object dtype, try parse as utc instead
if res.dtype == "object":
if res.dtype in ("object", "string"):
try:
res = pd.to_datetime(ser, utc=True, **datetime_kwargs)
except Exception:
pass

if res.dtype != "object":
if res.dtype.kind == "M":
# GDAL only supports ms precision, convert outputs to match.
# Pandas 2.0 supports datetime[ms] directly, prior versions only support [ns],
# Instead, round the values to [ms] precision.
Expand Down Expand Up @@ -282,11 +288,18 @@ def read_dataframe(
)

if use_arrow:
import pyarrow as pa

meta, table = result

# split_blocks and self_destruct decrease memory usage, but have as side effect
# that accessing table afterwards causes crash, so del table to avoid.
kwargs = {"self_destruct": True}
if PANDAS_GE_30:
kwargs["types_mapper"] = {
pa.string(): pd.StringDtype(na_value=np.nan),
pa.large_string(): pd.StringDtype(na_value=np.nan),
}.get
if arrow_to_pandas_kwargs is not None:
kwargs.update(arrow_to_pandas_kwargs)
df = table.to_pandas(**kwargs)
Expand Down
33 changes: 24 additions & 9 deletions pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,20 @@ def test_read_layer(tmp_path, use_arrow):

# create a multilayer GPKG
expected1 = gp.GeoDataFrame(geometry=[Point(0, 0)], crs="EPSG:4326")
if use_arrow:
    # TODO this needs to be fixed on the geopandas side (to ensure the
    # GeoDataFrame() constructor does this); when use_arrow is enabled we
    # already get a columns Index with string dtype
expected1.columns = expected1.columns.astype("str")
write_dataframe(
expected1,
filename,
layer="layer1",
)

expected2 = gp.GeoDataFrame(geometry=[Point(1, 1)], crs="EPSG:4326")
if use_arrow:
expected2.columns = expected2.columns.astype("str")
write_dataframe(expected2, filename, layer="layer2", append=True)

assert np.array_equal(
Expand Down Expand Up @@ -378,7 +385,7 @@ def test_read_null_values(tmp_path, use_arrow):
df = read_dataframe(filename, use_arrow=use_arrow, read_geometry=False)

# make sure that Null values are preserved
assert np.array_equal(df.col.values, expected.col.values)
assert df["col"].isna().all()


def test_read_fid_as_index(naturalearth_lowres_all_ext, use_arrow):
Expand Down Expand Up @@ -692,6 +699,13 @@ def test_read_skip_features(naturalearth_lowres_all_ext, use_arrow, skip_feature
# In .geojsonl the vertices are reordered, so normalize
is_jsons = ext == ".geojsonl"

if skip_features == 200 and not use_arrow:
# result is an empty dataframe, so no proper dtype inference happens
# for the numpy object dtype arrays
df[["continent", "name", "iso_a3"]] = df[
["continent", "name", "iso_a3"]
].astype("str")

assert_geodataframe_equal(
df,
expected,
Expand Down Expand Up @@ -1549,11 +1563,12 @@ def test_write_read_mixed_column_values(tmp_path):
write_dataframe(test_gdf, output_path)
output_gdf = read_dataframe(output_path)
assert len(test_gdf) == len(output_gdf)
for idx, value in enumerate(mixed_values):
if value in (None, np.nan):
assert output_gdf["mixed"][idx] is None
else:
assert output_gdf["mixed"][idx] == str(value)
# mixed values as object dtype are currently written as strings
expected = pd.Series(
[str(value) if value not in (None, np.nan) else None for value in mixed_values],
name="mixed",
)
assert_series_equal(output_gdf["mixed"], expected)


@requires_arrow_write_api
Expand Down Expand Up @@ -1586,8 +1601,8 @@ def test_write_read_null(tmp_path, use_arrow):
assert pd.isna(result_gdf["float64"][1])
assert pd.isna(result_gdf["float64"][2])
assert result_gdf["object_str"][0] == "test"
assert result_gdf["object_str"][1] is None
assert result_gdf["object_str"][2] is None
assert pd.isna(result_gdf["object_str"][1])
assert pd.isna(result_gdf["object_str"][2])


@pytest.mark.requires_arrow_write_api
Expand Down Expand Up @@ -1854,7 +1869,7 @@ def test_write_nullable_dtypes(tmp_path, use_arrow):
expected["col2"] = expected["col2"].astype("float64")
expected["col3"] = expected["col3"].astype("float32")
expected["col4"] = expected["col4"].astype("float64")
expected["col5"] = expected["col5"].astype(object)
expected["col5"] = expected["col5"].astype("str")
expected.loc[1, "col5"] = None # pandas converts to pd.NA on line above
assert_geodataframe_equal(output_gdf, expected)

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,6 @@ section-order = [
"geopandas.tests",
"geopandas.testing",
]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
Loading