diff --git a/arro3-core/python/arro3/core/_data_type.pyi b/arro3-core/python/arro3/core/_data_type.pyi index 9ca849f..6cf36a9 100644 --- a/arro3-core/python/arro3/core/_data_type.pyi +++ b/arro3-core/python/arro3/core/_data_type.pyi @@ -154,7 +154,8 @@ class DataType: Args: unit: one of `'s'` [second], `'ms'` [millisecond], `'us'` [microsecond], or `'ns'` [nanosecond] - tz: Time zone name. None indicates time zone naive. Defaults to None. + tz: Time zone name. None indicates time zone naive. Defaults to None. Supported + values are IANA time-zones, see `pytz.all_timezones` for a list of supported values. Returns: _description_ diff --git a/pyo3-arrow/src/array.rs b/pyo3-arrow/src/array.rs index 6a0539a..96e36cd 100644 --- a/pyo3-arrow/src/array.rs +++ b/pyo3-arrow/src/array.rs @@ -8,12 +8,15 @@ use arrow_array::types::{ use arrow_array::{ Array, ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Datum, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, StringViewArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, }; use arrow_cast::cast; use arrow_cast::display::ArrayFormatter; -use arrow_schema::{ArrowError, DataType, Field, FieldRef}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, TimeUnit}; use arrow_select::concat::concat; use arrow_select::take::take; +use chrono::{FixedOffset, Utc}; use numpy::PyUntypedArray; use pyo3::exceptions::{PyIndexError, PyNotImplementedError, PyValueError}; use pyo3::intern; @@ -212,6 +215,7 @@ impl PyArray { "type must be passed for non-Arrow input", ))? .into_inner(); + let array: ArrayRef = match field.data_type() { DataType::Float32 => impl_primitive!(f32, Float32Type), DataType::Float64 => impl_primitive!(f64, Float64Type), @@ -284,6 +288,57 @@ impl PyArray { .collect::>(); Arc::new(StringViewArray::from(slices)) } + DataType::Timestamp(unit, tz) => { + // We normalize all datetimes to datetimes in UTC. 
+ let values: Vec>> = match tz { + Some(_) => { + let vs: Vec>> = obj.extract()?; + vs.into_iter() + .map(|v| v.map(|dt| dt.with_timezone(&Utc))) + .collect() + } + None => { + let vs: Vec> = obj.extract()?; + vs.into_iter() + .map(|v| v.map(|naive| naive.and_utc())) + .collect() + } + }; + match unit { + TimeUnit::Second => { + let values: Vec<_> = + values.iter().map(|v| v.map(|x| x.timestamp())).collect(); + Arc::new(TimestampSecondArray::from(values).with_timezone_opt(tz.clone())) + } + TimeUnit::Millisecond => { + let values: Vec> = values + .iter() + .map(|v| v.map(|x| x.timestamp_millis())) + .collect(); + Arc::new( + TimestampMillisecondArray::from(values).with_timezone_opt(tz.clone()), + ) + } + TimeUnit::Microsecond => { + let values: Vec> = values + .iter() + .map(|v| v.map(|x| x.timestamp_micros())) + .collect(); + Arc::new( + TimestampMicrosecondArray::from(values).with_timezone_opt(tz.clone()), + ) + } + TimeUnit::Nanosecond => { + let values: Vec> = values + .iter() + .map(|v| v.map(|x| x.timestamp_nanos_opt().unwrap())) + .collect(); + Arc::new( + TimestampNanosecondArray::from(values).with_timezone_opt(tz.clone()), + ) + } + } + } dt => { return Err(PyNotImplementedError::new_err(format!( "Array constructor for {dt} not yet implemented." 
diff --git a/tests/core/test_arrays/test_array_datetime.py b/tests/core/test_arrays/test_array_datetime.py
new file mode 100644
index 0000000..21dfd38
--- /dev/null
+++ b/tests/core/test_arrays/test_array_datetime.py
@@ -0,0 +1,57 @@
+import zoneinfo
+from datetime import datetime, timezone
+
+import pytest
+from arro3.core import Array, DataType
+
+
+def _expected_microsecond(dt: datetime, unit: str) -> int:
+    """Return the microsecond value expected after a ``unit`` round trip.
+
+    The tests below assert that ``'s'`` drops the sub-second part, ``'ms'`` keeps
+    whole milliseconds only, and ``'us'``/``'ns'`` keep every microsecond.
+    """
+    if unit == "s":
+        return 0
+    if unit == "ms":
+        return dt.microsecond // 1000 * 1000
+    return dt.microsecond
+
+
+@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
+def test_array_timestamp_timezone(unit):
+    """Round-trip a naive datetime through a timestamp array of each unit.
+
+    NOTE: despite the name, no time zone is involved here; the name is kept so
+    existing test IDs stay stable.
+    """
+    dt = datetime(1999, 8, 7, 11, 12, 13, 141516)
+    arr = Array([dt, None], type=DataType.timestamp(unit))
+
+    result: datetime = arr.to_pylist()[0]
+
+    # Whole seconds must match for every unit; sub-second precision is checked
+    # separately because it depends on the unit.
+    assert result.replace(microsecond=0) == dt.replace(microsecond=0)
+    assert result.microsecond == _expected_microsecond(dt, unit)
+
+
+@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
+@pytest.mark.parametrize("tz_name", ["UTC", "America/Chicago", "Europe/Madrid"])
+def test_array_timestamp_tz(unit, tz_name):
+    """Round-trip a time-zone-aware datetime for each unit and time zone."""
+    dt = datetime(1999, 8, 7, 11, 12, 13, 141516)
+
+    # Attach the zone's offset as a fixed-offset tzinfo.  ``replace`` (rather than
+    # ``astimezone``) is deliberate: ``astimezone`` on a naive datetime assumes
+    # the machine's local time zone, which would make the input host-dependent.
+    tzinfo = zoneinfo.ZoneInfo(tz_name)
+    expected: datetime = dt.replace(tzinfo=timezone(tzinfo.utcoffset(dt)))
+
+    arr = Array([expected, None], type=DataType.timestamp(unit, tz=tz_name))
+    result: datetime = arr.to_pylist()[0]
+
+    # Compare whole seconds first; sub-second precision depends on the unit.
+    assert result.replace(microsecond=0) == expected.replace(microsecond=0)
+    assert result.utcoffset() == expected.utcoffset()
+    assert result.microsecond == _expected_microsecond(expected, unit)