Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions daft/datasets/common_crawl.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def common_crawl(
crawl: The crawl identifier, e.g. "CC-MAIN-2025-33".
segment: Specific segment to fetch within the crawl. If not provided, defaults to all segments in the crawl.
content: Specifies the type of content to load. Options are:
- "raw" or "warc": Raw WARC files containing full HTTP responses
- "text" or "wet": Extracted plain text content
- "metadata" or "wat": Metadata about crawled pages
+ "raw" or "warc": Raw WARC files containing full HTTP responses
+ "text" or "wet": Extracted plain text content
+ "metadata" or "wat": Metadata about crawled pages
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that in VSCode rendering, it is happier with + signs than - or *. All three are valid in Markdown, and others like IPython don't care. So I vote for using these.

num_files: Limit the number of files to process. If not provided, processes all matching files.
io_config: IO configuration for accessing S3.
in_aws: Where to fetch the common crawl data from. If running in AWS, this must be set to True. If outside of AWS,
Expand Down
25 changes: 19 additions & 6 deletions daft/datatype.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import threading
import typing
import warnings
from types import GenericAlias
from types import GenericAlias, UnionType
from typing import TYPE_CHECKING, Any

from packaging.version import parse
Expand Down Expand Up @@ -144,16 +145,19 @@ def __init__(self) -> None:
)

@classmethod
def infer_from_type(cls, t: type | GenericAlias) -> DataType:
def infer_from_type(cls, t: type | GenericAlias | UnionType) -> DataType:
"""Infer Daft DataType from a Python type."""
# NOTE: Make sure this matches the logic in `Literal::from_pyobj` in Rust
# NOTE: The base type for Union is hidden, so it requires special handling
# TODO: TypeForm would cover everything: https://peps.python.org/pep-0747/

assert isinstance(t, (type, GenericAlias)), f"Input to DataType.infer_from_type must be a type, found: {t}"
assert isinstance(t, (type, GenericAlias, UnionType)) or typing.get_origin(t) is typing.Union, (
f"Input to DataType.infer_from_type must be a type, found {t} (type {type(t)})"
)

import datetime
import decimal
import importlib
import typing
from typing import is_typeddict

import daft.file
Expand Down Expand Up @@ -182,7 +186,16 @@ def check_type(type_or_path: type | str) -> bool:

return issubclass(origin, type_obj)

if check_type(type(None)):
# NOTE: This has to be first to handle the special case of typing.Union
if origin is typing.Union or check_type(UnionType): # type: ignore[comparison-overlap]
inner_types = set(DataType.infer_from_type(arg) for arg in args)
if len(inner_types) == 1:
return inner_types.pop()
elif len(inner_types) == 2 and cls.null() in inner_types:
return inner_types.difference([cls.null()]).pop()
else:
return cls.python()
elif check_type(type(None)):
return cls.null()
elif check_type(bool):
return cls.bool()
Expand Down Expand Up @@ -457,7 +470,7 @@ def _infer(cls, user_provided_type: DataTypeLike) -> DataType:
return user_provided_type
elif isinstance(user_provided_type, str):
return cls.from_sql(user_provided_type)
elif isinstance(user_provided_type, (type, GenericAlias)):
elif isinstance(user_provided_type, (type, GenericAlias, UnionType)):
return cls.infer_from_type(user_provided_type)
else:
raise TypeError("DataType._infer expects a DataType, string, or type")
Expand Down
140 changes: 82 additions & 58 deletions tests/test_datatype_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import decimal
from typing import NamedTuple, TypedDict
from typing import NamedTuple, Optional, TypedDict

import jax
import jaxtyping
Expand Down Expand Up @@ -37,6 +37,12 @@ class NestedPydanticModel(BaseModel):
active: bool


# Class-based TypedDict
class ClassFooBar(TypedDict):
foo: str
bar: int


class PydanticWithAlias(BaseModel):
model_config = {"serialize_by_alias": True}
full_name: str = Field(alias="name", serialization_alias="fullName")
Expand Down Expand Up @@ -79,78 +85,94 @@ class PydanticWithNamedTuple(BaseModel):
@pytest.mark.parametrize(
"user_provided_type, expected_datatype",
[
(type(None), dt.null()),
(bool, dt.bool()),
(str, dt.string()),
(bytes, dt.binary()),
(int, dt.int64()),
(float, dt.float64()),
(datetime.datetime, dt.timestamp(TimeUnit.us())),
(datetime.date, dt.date()),
(datetime.time, dt.time(TimeUnit.us())),
(datetime.timedelta, dt.duration(TimeUnit.us())),
(list, dt.list(dt.python())),
(list[str], dt.list(dt.string())),
(list[list], dt.list(dt.list(dt.python()))),
(list[list[str]], dt.list(dt.list(dt.string()))),
(
pytest.param(type(None), dt.null(), id="null"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added labels so it's easier to tell which one fails.

pytest.param(bool, dt.bool(), id="bool"),
pytest.param(str, dt.string(), id="str"),
pytest.param(bytes, dt.binary(), id="bytes"),
pytest.param(int, dt.int64(), id="int"),
pytest.param(float, dt.float64(), id="float"),
pytest.param(datetime.datetime, dt.timestamp(TimeUnit.us()), id="datetime"),
pytest.param(datetime.date, dt.date(), id="date"),
pytest.param(datetime.time, dt.time(TimeUnit.us()), id="time"),
pytest.param(datetime.timedelta, dt.duration(TimeUnit.us()), id="timedelta"),
pytest.param(list, dt.list(dt.python()), id="list_untyped"),
pytest.param(list[str], dt.list(dt.string()), id="list_str"),
pytest.param(list[list], dt.list(dt.list(dt.python())), id="list_list_untyped"),
pytest.param(list[list[str]], dt.list(dt.list(dt.string())), id="list_list_str"),
pytest.param(str | int, dt.python(), id="union"),
pytest.param(str | None, dt.string(), id="union_str_none"),
pytest.param(str | int | None, dt.python(), id="union_str_int_none"),
pytest.param(Optional[str], dt.string(), id="optional_str"), # noqa: UP045
pytest.param(
TypedDict("Foobar", {"foo": str, "bar": int}),
dt.struct({"foo": dt.string(), "bar": dt.int64()}),
id="typeddict_inline",
),
pytest.param(ClassFooBar, dt.struct({"foo": dt.string(), "bar": dt.int64()}), id="typeddict_class"),
pytest.param(dict, dt.map(dt.python(), dt.python()), id="dict_untyped"),
pytest.param(dict[str, str], dt.map(dt.string(), dt.string()), id="dict_str_str"),
pytest.param(tuple, dt.list(dt.python()), id="tuple_untyped"),
pytest.param(tuple[str, ...], dt.list(dt.string()), id="tuple_str_variadic"),
pytest.param(tuple[str, int], dt.struct({"_0": dt.string(), "_1": dt.int64()}), id="tuple_str_int_named"),
pytest.param(np.ndarray, dt.tensor(dt.python()), id="numpy_ndarray"),
pytest.param(torch.Tensor, dt.tensor(dt.python()), id="torch_tensor"),
pytest.param(torch.FloatTensor, dt.tensor(dt.float32()), id="torch_float32"),
pytest.param(torch.DoubleTensor, dt.tensor(dt.float64()), id="torch_float64"),
pytest.param(torch.ByteTensor, dt.tensor(dt.uint8()), id="torch_uint8"),
pytest.param(torch.CharTensor, dt.tensor(dt.int8()), id="torch_int8"),
pytest.param(torch.ShortTensor, dt.tensor(dt.int16()), id="torch_int16"),
pytest.param(torch.IntTensor, dt.tensor(dt.int32()), id="torch_int32"),
pytest.param(torch.LongTensor, dt.tensor(dt.int64()), id="torch_int64"),
pytest.param(torch.BoolTensor, dt.tensor(dt.bool()), id="torch_bool"),
*(
[]
if tensorflow is None
else [pytest.param(tensorflow.Tensor, dt.tensor(dt.python()), id="tensorflow_tensor")]
),
(dict, dt.map(dt.python(), dt.python())),
(dict[str, str], dt.map(dt.string(), dt.string())),
(tuple, dt.list(dt.python())),
(tuple[str, ...], dt.list(dt.string())),
(tuple[str, int], dt.struct({"_0": dt.string(), "_1": dt.int64()})),
(np.ndarray, dt.tensor(dt.python())),
(torch.Tensor, dt.tensor(dt.python())),
(torch.FloatTensor, dt.tensor(dt.float32())),
(torch.DoubleTensor, dt.tensor(dt.float64())),
(torch.ByteTensor, dt.tensor(dt.uint8())),
(torch.CharTensor, dt.tensor(dt.int8())),
(torch.ShortTensor, dt.tensor(dt.int16())),
(torch.IntTensor, dt.tensor(dt.int32())),
(torch.LongTensor, dt.tensor(dt.int64())),
(torch.BoolTensor, dt.tensor(dt.bool())),
*([] if tensorflow is None else [(tensorflow.Tensor, dt.tensor(dt.python()))]),
(jax.Array, dt.tensor(dt.python())),
(npt.NDArray[int], dt.tensor(dt.int64())),
(np.bool_, dt.bool()),
(np.int8, dt.int8()),
(np.uint8, dt.uint8()),
(np.int16, dt.int16()),
(np.uint16, dt.uint16()),
(np.int32, dt.int32()),
(np.uint32, dt.uint32()),
(np.int64, dt.int64()),
(np.uint64, dt.uint64()),
(np.float32, dt.float32()),
(np.float64, dt.float64()),
(np.datetime64, dt.timestamp(TimeUnit.us())),
(pandas.Series, dt.list(dt.python())),
(PIL.Image.Image, dt.image()),
(Series, dt.list(dt.python())),
(File, dt.file(MediaType.unknown())),
(VideoFile, dt.file(MediaType.video())),
(object, dt.python()),
pytest.param(jax.Array, dt.tensor(dt.python()), id="jax_array"),
pytest.param(npt.NDArray[int], dt.tensor(dt.int64()), id="numpy_ndarray_int"),
pytest.param(np.bool_, dt.bool(), id="numpy_bool"),
pytest.param(np.int8, dt.int8(), id="numpy_int8"),
pytest.param(np.uint8, dt.uint8(), id="numpy_uint8"),
pytest.param(np.int16, dt.int16(), id="numpy_int16"),
pytest.param(np.uint16, dt.uint16(), id="numpy_uint16"),
pytest.param(np.int32, dt.int32(), id="numpy_int32"),
pytest.param(np.uint32, dt.uint32(), id="numpy_uint32"),
pytest.param(np.int64, dt.int64(), id="numpy_int64"),
pytest.param(np.uint64, dt.uint64(), id="numpy_uint64"),
pytest.param(np.float32, dt.float32(), id="numpy_float32"),
pytest.param(np.float64, dt.float64(), id="numpy_float64"),
pytest.param(np.datetime64, dt.timestamp(TimeUnit.us()), id="numpy_datetime64"),
pytest.param(pandas.Series, dt.list(dt.python()), id="pandas_series"),
pytest.param(PIL.Image.Image, dt.image(), id="pil_image"),
pytest.param(Series, dt.list(dt.python()), id="daft_series"),
pytest.param(File, dt.file(MediaType.unknown()), id="daft_file"),
pytest.param(VideoFile, dt.file(MediaType.video()), id="daft_video_file"),
pytest.param(object, dt.python(), id="object_python"),
# Pydantic models
(SimplePydanticModel, dt.struct({"name": dt.string(), "age": dt.int64()})),
(
pytest.param(
SimplePydanticModel,
dt.struct({"name": dt.string(), "age": dt.int64()}),
id="pydantic_simple",
),
pytest.param(
NestedPydanticModel,
dt.struct(
{
"user": dt.struct({"name": dt.string(), "age": dt.int64()}),
"active": dt.bool(),
}
),
id="pydantic_nested",
),
# TODO: Uncomment this when we update to pydantic>=2.11 which supports `serialize_by_alias`
# (PydanticWithAlias, dt.struct({"fullName": dt.string(), "age": dt.int64()})),
(
pytest.param(
PydanticWithAliasNoSerializeByAlias,
dt.struct({"full_name": dt.string(), "user_age": dt.int64()}),
id="pydantic_alias_no_serialize",
),
(
pytest.param(
PydanticWithComputedField,
dt.struct(
{
Expand All @@ -159,17 +181,19 @@ class PydanticWithNamedTuple(BaseModel):
"full_name": dt.string(),
}
),
id="pydantic_computed_field",
),
(
pytest.param(
PydanticWithMixedTypes,
dt.struct(
{
"numbers": dt.list(dt.int64()),
"metadata": dt.map(dt.string(), dt.string()),
}
),
id="pydantic_mixed_types",
),
(EmptyPydanticModel, dt.struct({})),
pytest.param(EmptyPydanticModel, dt.struct({}), id="pydantic_empty"),
# TODO: uncomment once we support named tuples
# (SimpleNamedTuple, dt.struct({"foo": dt.string(), "bar": dt.int64()})),
# (PydanticWithNamedTuple, dt.struct({"values": dt.struct({"foo": dt.string(), "bar": dt.int64()})})),
Expand Down
Loading