-
Notifications
You must be signed in to change notification settings - Fork 0
Add a PolarsDataFeed and deprecate the CSVDataFeed #43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,125 @@ | ||||
| """Polars data feed implementation.""" | ||||
|
|
||||
| import logging | ||||
| from collections.abc import Generator | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
|
|
||||
| import polars as pl | ||||
|
|
||||
| from alphaflow import DataFeed | ||||
| from alphaflow.events.market_data_event import MarketDataEvent | ||||
|
|
||||
| logger = logging.getLogger(__name__) | ||||
|
|
||||
|
|
||||
| class PolarsDataFeed(DataFeed): | ||||
| """Data feed that loads market data from Polars dataframes.""" | ||||
|
|
||||
| def __init__( | ||||
| self, | ||||
| df_or_file_path: Path | str | pl.DataFrame | pl.LazyFrame, | ||||
| *, | ||||
| col_timestamp: str = "Date", | ||||
| col_symbol: str = "Symbol", | ||||
| col_open: str = "Open", | ||||
| col_high: str = "High", | ||||
| col_low: str = "Low", | ||||
| col_close: str = "Close", | ||||
| col_volume: str = "Volume", | ||||
| ) -> None: | ||||
| """Initialize the Polars data feed. | ||||
|
|
||||
| Args: | ||||
| df_or_file_path: Polars DataFrame or LazyFrame, or a path to a CSV or Parquet file containing market data. | ||||
| col_timestamp: Name of the timestamp column. | ||||
| col_symbol: Name of the symbol column. | ||||
| col_open: Name of the open price column. | ||||
| col_high: Name of the high price column. | ||||
| col_low: Name of the low price column. | ||||
| col_close: Name of the close price column. | ||||
| col_volume: Name of the volume column. | ||||
|
|
||||
| """ | ||||
| self.df_or_file_path = df_or_file_path | ||||
| self._col_timestamp = col_timestamp | ||||
| self._col_symbol = col_symbol | ||||
| self._col_open = col_open | ||||
| self._col_high = col_high | ||||
| self._col_low = col_low | ||||
| self._col_close = col_close | ||||
| self._col_volume = col_volume | ||||
|
|
||||
| def run( | ||||
| self, | ||||
| symbol: str, | ||||
| start_timestamp: datetime | None, | ||||
| end_timestamp: datetime | None, | ||||
| ) -> Generator[MarketDataEvent, None, None]: | ||||
| """Load and yield market data events from the Polars dataframe. | ||||
|
|
||||
| Args: | ||||
| symbol: The ticker symbol to load data for. | ||||
| start_timestamp: Optional start time for filtering data. | ||||
| end_timestamp: Optional end time for filtering data. | ||||
|
|
||||
| Yields: | ||||
| MarketDataEvent objects containing OHLCV data. | ||||
|
|
||||
| Raises: | ||||
| ValueError: If required columns are missing from the Polars dataframe. | ||||
|
|
||||
| """ | ||||
| if isinstance(self.df_or_file_path, (str, Path)): | ||||
| df_path = Path(self.df_or_file_path) if isinstance(self.df_or_file_path, str) else self.df_or_file_path | ||||
| if df_path.suffix in {".parquet", ".parq"}: | ||||
| df = pl.read_parquet(df_path) | ||||
| df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime)) | ||||
|
||||
| df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring incorrectly refers to "path to the Polars dataframe" when it should say "path to a CSV or Parquet file". Polars DataFrames are in-memory data structures, not files on disk. The parameter can accept either a DataFrame/LazyFrame OR a file path to CSV/Parquet files.