Skip to content

Commit

Permalink
Fix: Set correct index in ObservationParquetReaderValue (#427)
Browse files Browse the repository at this point in the history
* fix: Drop station_id column and set correct index in ObservationParquetReaderValue

* Add test on index
  • Loading branch information
osundwajeff authored Feb 5, 2025
1 parent 0d43799 commit 8e7ab25
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
4 changes: 4 additions & 0 deletions django_project/gap/providers/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,10 @@ def to_netcdf(self, suffix=".nc"):
try:
# Execute the DuckDB query and fetch data
df = self.conn.sql(self.query).df()
# Drop the station_id column
df = df.drop(columns=["station_id"])
# Set correct index
df = df.set_index(["date", "lat", "lon"])

# Convert DataFrame to Xarray Dataset
ds = xr.Dataset.from_dataframe(df)
Expand Down
74 changes: 74 additions & 0 deletions django_project/gap/tests/providers/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from unittest.mock import patch, MagicMock
import duckdb
import xarray as xr
import pandas as pd

from django.test import TestCase
from datetime import datetime
Expand Down Expand Up @@ -818,3 +820,75 @@ def test_duckdb_connection(self, mock_duckdb_connect):
mock_conn.load_extension.assert_any_call("httpfs")
mock_conn.install_extension.assert_any_call("spatial")
mock_conn.load_extension.assert_any_call("spatial")

@patch(
    (
        "gap.providers.observation."
        "ObservationParquetReaderValue._get_file_remote_url")
)
@patch("gap.providers.observation.storages")
def test_to_netcdf_drops_station_id_and_sets_index(
    self,
    mock_storages,
    mock_get_file_remote_url
):
    """Test that to_netcdf drops 'station_id' and sets index correctly.

    The DataFrame that ``to_netcdf`` hands to
    ``xr.Dataset.from_dataframe`` is captured with a wrapping spy, so
    the assertions inspect the frame produced by the method under test
    rather than the untouched input frame or the SQL query string.
    """
    # Frame returned by the mocked DuckDB query.
    source_df = pd.DataFrame({
        "date": pd.date_range(start="2022-01-01", periods=5),
        "lat": [1.0] * 5,
        "lon": [2.0] * 5,
        "station_id": ["A", "B", "C", "D", "E"],  # Should be dropped
        "temperature": [10, 15, 20, 25, 30]  # Data column
    })

    # Create mock DuckDB connection
    mock_conn = MagicMock()
    mock_conn.sql.return_value.df.return_value = source_df

    # Build DatasetReaderInput
    location_input = DatasetReaderInput.from_point(Point(36.8, -1.3))

    # Create ObservationParquetReaderValue
    reader_value = ObservationParquetReaderValue(
        val=mock_conn,
        location_input=location_input,
        attributes=[],
        start_date=datetime(2022, 1, 1),
        end_date=datetime(2022, 12, 31),
        query="SELECT * FROM test"
    )

    # Mock file storage behavior
    mock_get_file_remote_url.return_value = "s3://test-bucket/output.nc"
    mock_s3_storage = MagicMock()
    mock_storages.__getitem__.return_value = mock_s3_storage

    # Run `to_netcdf` while spying on the DataFrame -> Dataset
    # conversion. `wraps` keeps the real conversion, so the rest of
    # `to_netcdf` behaves exactly as it would without the spy.
    with patch.object(
        xr.Dataset,
        "from_dataframe",
        wraps=xr.Dataset.from_dataframe
    ) as mock_from_dataframe:
        netcdf_output = reader_value.to_netcdf()

    # **Assertions**
    # The frame passed to xarray is the one produced by `to_netcdf`.
    mock_from_dataframe.assert_called_once()
    converted_df = mock_from_dataframe.call_args[0][0]

    # Ensure station_id column is removed
    self.assertNotIn(
        "station_id",
        converted_df.columns, "station_id column was not removed"
    )

    # Ensure index is set correctly
    self.assertEqual(
        list(converted_df.index.names),
        ["date", "lat", "lon"], "Incorrect index set"
    )

    # Ensure NetCDF file was saved
    mock_s3_storage.save.assert_called_once()
    self.assertEqual(netcdf_output, "s3://test-bucket/output.nc")

0 comments on commit 8e7ab25

Please sign in to comment.