From 996661f9b19371156feba912760e8f8a4223cbad Mon Sep 17 00:00:00 2001 From: Mostafa Farrag Date: Fri, 17 Jan 2025 15:06:00 +0100 Subject: [PATCH] add new parameter `quantities_names` to the `TimModel` class --- hydrolib/core/dflowfm/tim/models.py | 41 ++++++++++++++++++++++++++--- tests/dflowfm/test_tim.py | 14 ++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/hydrolib/core/dflowfm/tim/models.py b/hydrolib/core/dflowfm/tim/models.py index 1cdadce62..73ffb11e3 100644 --- a/hydrolib/core/dflowfm/tim/models.py +++ b/hydrolib/core/dflowfm/tim/models.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from pandas import DataFrame from pydantic.v1 import Field @@ -31,6 +31,8 @@ class TimModel(ParsableFileModel): Header comments from the .tim file. timeseries : List[TimRecord] A list of TimRecord objects, each containing a time value and associated data. + quantities_names : Optional[List[str]] + List of names for the quantities in the timeseries. Methods: -------- @@ -44,6 +46,10 @@ class TimModel(ParsableFileModel): Returns the parser callable for .tim files. _validate_timeseries_values(cls, v: List[TimRecord]) -> List[TimRecord] Validates the timeseries data. + as_dataframe(columns: List[Any] = None) -> DataFrame + Returns the timeseries as a pandas DataFrame. + _validate_quantities_names(cls, v, values) -> List[str] + Validates the quantities_names equals to the values or each record. Args: ----- @@ -106,6 +112,8 @@ class TimModel(ParsableFileModel): timeseries: List[TimRecord] = Field(default_factory=list) """List[TimRecord]: A list containing the timeseries.""" + quantities_names: Optional[List[str]] = Field(default=None) + @classmethod def _ext(cls) -> str: return ".tim" @@ -171,9 +179,28 @@ def _raise_error_if_duplicate_time(timeseries: List[TimRecord]) -> None: ) seen_times.add(timrecord.time) - def as_dataframe(self) -> DataFrame: + @validator("quantities_names") + def _validate_quantities_names(cls, v, values): + """Validate if the amount of quantities_names match the amount of columns in the timeseries. + + The validator compared the amount of quantities_names with the amount of columns in the first record of + the timeseries. + """ + if v is not None: + first_records_data = values["timeseries"][0].data + if len(v) != len(first_records_data): + raise ValueError( + f"The number of quantities_names ({len(v)}) must match the number of columns in the Tim file ({len(first_records_data)})." + ) + return v + + def as_dataframe(self, columns: List[Any] = None) -> DataFrame: """Return the timeseries as a pandas DataFrame. + Args: + columns (List[Any, str], optional, Defaults to None): + The column names for the DataFrame. + Returns: DataFrame: The timeseries as a pandas DataFrame. @@ -187,7 +214,15 @@ def as_dataframe(self) -> DataFrame: 10.0 1.232 2.343 3.454 20.0 4.565 5.676 6.787 30.0 1.500 2.600 3.700 + + To add column names to the DataFrame: + >>> df = tim_model.as_dataframe(columns=["Column1", "Column2", "Column3"]) + >>> print(df) + Column1 Column2 Column3 + 10.0 1.232 2.343 3.454 + 20.0 4.565 5.676 6.787 + 30.0 1.500 2.600 3.700 """ time_series = [record.data for record in self.timeseries] index = [record.time for record in self.timeseries] - return DataFrame(time_series, index=index) + return DataFrame(time_series, index=index, columns=columns) diff --git a/tests/dflowfm/test_tim.py b/tests/dflowfm/test_tim.py index a70ea6d7d..22ca9af8f 100644 --- a/tests/dflowfm/test_tim.py +++ b/tests/dflowfm/test_tim.py @@ -127,6 +127,20 @@ def test_as_dataframe(self): ] assert df.loc[:, 0].to_list() == vals + df = model.as_dataframe(columns=["data"]) + assert df.columns.to_list() == ["data"] + + def test_with_quantities_names(self): + model = TimModel( + timeseries=self.single_data_for_timeseries_floats, quantities_names=["a"] + ) + assert model.quantities_names == ["a"] + with pytest.raises(ValueError): + TimModel( + timeseries=self.single_data_for_timeseries_floats, + quantities_names=["a", "b"], + ) + @pytest.mark.parametrize( "input_data, reference_path", [