Skip to content

Commit fd2a12e

Browse files
Refactoring and adding tests.
1 parent 280cf52 commit fd2a12e

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

pins/drivers.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,36 +196,48 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
196196

197197
def default_title(obj, name):
    """Build the default title string for a pinned object.

    Args:
        obj:
            The object being pinned. If it is recognized as a DataFrame, its
            ``shape`` attribute is used in the title.
        name:
            The pin name, used as the title prefix.

    Returns:
        ``"<name>: a pinned <dims> DataFrame"`` when ``_choose_df_lib``
        recognizes ``obj``, otherwise ``"<name>: a pinned <class> object"``,
        where ``<class>`` is the qualified name of ``type(obj)``.
    """
    try:
        lib = _choose_df_lib(obj)
    except NotImplementedError:
        # Not a recognized DataFrame: fall back to a generic class-based title.
        return f"{name}: a pinned {type(obj).__qualname__} object"

    # Display name of the object type, per DataFrame library.
    objname_by_lib: dict[_DFLib, str] = {
        "polars": "DataFrame",
        "pandas": "DataFrame",
    }

    # TODO(compat): title says CSV rather than data.frame
    # see https://github.com/machow/pins-python/issues/5
    dims = " x ".join(str(extent) for extent in obj.shape)
    return f"{name}: a pinned {dims} {objname_by_lib[lib]}"
208213

209214

210215
def _choose_df_lib(
211216
df,
212217
*,
213-
supported_libs: list[_DFLib] = ["pandas", "polars"],
218+
supported_libs: list[_DFLib] | None = None,
214219
file_type: str | None = None,
215220
) -> _DFLib:
216-
"""Return the type of DataFrame library used in the given DataFrame.
221+
"""Return the library associated with a DataFrame, e.g. "pandas".
222+
223+
The arguments `supported_libs` and `file_type` must be specified together, and are
224+
meant to be used when saving an object, to choose the appropriate library.
217225
218226
Args:
219227
df:
220228
The object to check - might not be a DataFrame necessarily.
221229
supported_libs:
222230
The DataFrame libraries to accept for this df.
223231
file_type:
224-
The file type we're trying to save to - used to give more specific error messages.
232+
The file type we're trying to save to - used to give more specific error
233+
messages.
225234
226235
Raises:
227-
NotImplementedError: If the DataFrame type is not recognized.
236+
NotImplementedError: If the DataFrame type is not recognized, or not supported.
228237
"""
238+
if (supported_libs is None) + (file_type is None) == 1:
239+
raise ValueError("Must provide both or neither of supported_libs and file_type")
240+
229241
df_libs: list[_DFLib] = []
230242

231243
# pandas
@@ -243,6 +255,7 @@ def _choose_df_lib(
243255
if isinstance(df, pl.DataFrame):
244256
df_libs.append("polars")
245257

258+
# Make sure there's only one library associated with the dataframe
246259
if len(df_libs) == 1:
247260
(df_lib,) = df_libs
248261
elif len(df_libs) > 1:
@@ -255,16 +268,14 @@ def _choose_df_lib(
255268
else:
256269
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
257270

258-
if df_lib not in supported_libs:
259-
if file_type is None:
260-
ftype_clause = "in pins"
261-
else:
262-
ftype_clause = f"for type {file_type!r}"
271+
# Raise if the library is not supported
272+
if supported_libs is not None and df_lib not in supported_libs:
273+
ftype_clause = f"for type {file_type!r}"
263274

264275
if len(supported_libs) == 1:
265276
msg = (
266277
f"Currently only {supported_libs[0]} DataFrames can be saved "
267-
f"{ftype_clause}. {df_lib} DataFrames are not yet supported."
278+
f"{ftype_clause}. DataFrames from {df_lib} are not yet supported."
268279
)
269280
else:
270281
msg = (

pins/tests/test_drivers.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from pins.config import PINS_ENV_INSECURE_READ
9-
from pins.drivers import default_title, load_data, save_data
9+
from pins.drivers import _choose_df_lib, default_title, load_data, save_data
1010
from pins.errors import PinsInsecureReadError
1111
from pins.meta import MetaRaw
1212
from pins.tests.helpers import rm_env
@@ -191,3 +191,57 @@ def test_driver_apply_suffix_false(tmp_path: Path):
191191
res_fname = save_data(df, p_obj, type_, apply_suffix=False)
192192

193193
assert Path(res_fname).name == "some_df"
194+
195+
196+
class TestChooseDFLib:
    """Tests for ``_choose_df_lib``: library detection and argument validation."""

    # Error raised when only one of supported_libs / file_type is passed.
    _BOTH_OR_NEITHER = "Must provide both or neither of supported_libs and file_type"

    def test_pandas(self):
        frame = pd.DataFrame({"x": [1]})
        assert _choose_df_lib(frame) == "pandas"

    def test_polars(self):
        frame = pl.DataFrame({"x": [1]})
        assert _choose_df_lib(frame) == "polars"

    def test_list_raises(self):
        # A plain list is not a DataFrame from any known library.
        expected = "Unrecognized DataFrame type: <class 'list'>"
        with pytest.raises(NotImplementedError, match=expected):
            _choose_df_lib([])

    def test_pandas_subclass(self):
        # Subclasses of pandas.DataFrame should still be detected as pandas.
        class MyDataFrame(pd.DataFrame):
            pass

        frame = MyDataFrame({"x": [1]})
        assert _choose_df_lib(frame) == "pandas"

    def test_ftype_compatible(self):
        frame = pd.DataFrame({"x": [1]})
        chosen = _choose_df_lib(frame, supported_libs=["pandas"], file_type="csv")
        assert chosen == "pandas"

    def test_ftype_incompatible(self):
        # polars input while only pandas is supported for this file type.
        frame = pl.DataFrame({"x": [1]})
        expected = (
            "Currently only pandas DataFrames can be saved for type 'csv'. "
            "DataFrames from polars are not yet supported."
        )
        with pytest.raises(NotImplementedError, match=expected):
            _choose_df_lib(frame, supported_libs=["pandas"], file_type="csv")

    def test_supported_alone_raises(self):
        with pytest.raises(ValueError, match=self._BOTH_OR_NEITHER):
            _choose_df_lib(..., supported_libs=["pandas"])

    def test_file_type_alone_raises(self):
        with pytest.raises(ValueError, match=self._BOTH_OR_NEITHER):
            _choose_df_lib(..., file_type="csv")

0 commit comments

Comments
 (0)