From 2ae949aae3550e2db283108c8726514889349b8c Mon Sep 17 00:00:00 2001 From: Brandon Schabell Date: Tue, 23 Dec 2025 22:15:54 -0600 Subject: [PATCH] Add a PolarsDataFeed and deprecate the CSVDataFeed --- CHANGELOG.md | 4 + README.md | 6 +- alphaflow/data_feeds/__init__.py | 3 +- alphaflow/data_feeds/csv_data_feed.py | 92 ++----- alphaflow/data_feeds/polars_data_feed.py | 125 +++++++++ alphaflow/tests/test_alphaflow.py | 20 +- alphaflow/tests/test_analyzer.py | 8 +- alphaflow/tests/test_broker.py | 20 +- ...st_data_feeds.py => test_csv_data_feed.py} | 15 +- alphaflow/tests/test_polars_data_feed.py | 257 ++++++++++++++++++ alphaflow/tests/test_portfolio.py | 12 +- alphaflow/tests/test_strategy.py | 12 +- docs/api/data_feeds.md | 4 +- docs/getting_started.md | 2 +- pyproject.toml | 1 + uv.lock | 2 + 16 files changed, 463 insertions(+), 120 deletions(-) create mode 100644 alphaflow/data_feeds/polars_data_feed.py rename alphaflow/tests/{test_data_feeds.py => test_csv_data_feed.py} (90%) create mode 100644 alphaflow/tests/test_polars_data_feed.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f98b0b5..7341648 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `on_missing_price` config option - control behavior when `get_price()` cannot find data: `"raise"` (default), `"warn"`, or `"ignore"` +- `PolarsDataFeed` - Load historical data directly from Polars DataFrames ### Changed - Replaced `Make` with `just` for development commands +### Deprecated +- `CSVDataFeed` - Use `PolarsDataFeed` instead for loading data from CSV files + ## [0.2.0] - 2025-11-11 ### Added diff --git a/README.md b/README.md index f1037f0..73369e0 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ from alphaflow import AlphaFlow from alphaflow.brokers import SimpleBroker - from alphaflow.data_feeds import CSVDataFeed + from alphaflow.data_feeds import PolarsDataFeed from alphaflow.strategies 
import BuyAndHoldStrategy # 1. Initialize AlphaFlow @@ -135,8 +135,8 @@ # 2. Create DataFeed (e.g., CSV-based daily bars) flow.set_data_feed( - CSVDataFeed( - file_path="historical_data.csv", + PolarsDataFeed( + df_or_file_path="historical_data.csv", ) ) diff --git a/alphaflow/data_feeds/__init__.py b/alphaflow/data_feeds/__init__.py index 937813a..98cfb24 100644 --- a/alphaflow/data_feeds/__init__.py +++ b/alphaflow/data_feeds/__init__.py @@ -3,6 +3,7 @@ from alphaflow.data_feeds.alpha_vantage_data_feed import AlphaVantageFeed from alphaflow.data_feeds.csv_data_feed import CSVDataFeed from alphaflow.data_feeds.fmp_data_feed import FMPDataFeed +from alphaflow.data_feeds.polars_data_feed import PolarsDataFeed from alphaflow.data_feeds.polygon_data_feed import PolygonDataFeed -__all__ = ["AlphaVantageFeed", "CSVDataFeed", "FMPDataFeed", "PolygonDataFeed"] +__all__ = ["AlphaVantageFeed", "CSVDataFeed", "FMPDataFeed", "PolarsDataFeed", "PolygonDataFeed"] diff --git a/alphaflow/data_feeds/csv_data_feed.py b/alphaflow/data_feeds/csv_data_feed.py index 4d842be..8f0b82b 100644 --- a/alphaflow/data_feeds/csv_data_feed.py +++ b/alphaflow/data_feeds/csv_data_feed.py @@ -1,19 +1,17 @@ """CSV file data feed implementation.""" import logging -from collections.abc import Generator -from datetime import datetime from pathlib import Path -import polars as pl +from typing_extensions import deprecated -from alphaflow import DataFeed -from alphaflow.events.market_data_event import MarketDataEvent +from alphaflow.data_feeds.polars_data_feed import PolarsDataFeed logger = logging.getLogger(__name__) -class CSVDataFeed(DataFeed): +@deprecated("CSVDataFeed is deprecated and will be removed in a future release. Please use PolarsDataFeed instead.") +class CSVDataFeed(PolarsDataFeed): """Data feed that loads market data from CSV files.""" def __init__( @@ -30,6 +28,8 @@ def __init__( ) -> None: """Initialize the CSV data feed. + **Deprecated**: Use PolarsDataFeed instead. 
+ Args: file_path: Path to the CSV file containing market data. col_timestamp: Name of the timestamp column. @@ -42,73 +42,13 @@ def __init__( """ self.file_path = Path(file_path) if isinstance(file_path, str) else file_path - self._col_timestamp = col_timestamp - self._col_symbol = col_symbol - self._col_open = col_open - self._col_high = col_high - self._col_low = col_low - self._col_close = col_close - self._col_volume = col_volume - - def run( - self, - symbol: str, - start_timestamp: datetime | None, - end_timestamp: datetime | None, - ) -> Generator[MarketDataEvent, None, None]: - """Load and yield market data events from the CSV file. - - Args: - symbol: The ticker symbol to load data for. - start_timestamp: Optional start time for filtering data. - end_timestamp: Optional end time for filtering data. - - Yields: - MarketDataEvent objects containing OHLCV data. - - Raises: - ValueError: If required columns are missing from the CSV. - - """ - logger.debug("Opening CSV file...") - df = pl.read_csv(self.file_path, try_parse_dates=True) - - required_cols = { - self._col_timestamp, - self._col_close, - self._col_high, - self._col_low, - self._col_open, - self._col_volume, - } - - missing_cols = required_cols.difference(df.columns) - if missing_cols: - raise ValueError(f"Missing columns: {missing_cols}") - - # Convert date column to datetime if needed (polars parses as date by default) - if df[self._col_timestamp].dtype == pl.Date: - df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime)) - - # Filter by symbol using polars - if self._col_symbol in df.columns: - df = df.filter(pl.col(self._col_symbol) == symbol) - - # Filter by timestamp bounds using polars - if start_timestamp: - df = df.filter(pl.col(self._col_timestamp) >= start_timestamp) - if end_timestamp: - df = df.filter(pl.col(self._col_timestamp) <= end_timestamp) - - # Convert to dicts once after all filtering - for row in df.iter_rows(named=True): - event = MarketDataEvent( - 
timestamp=row[self._col_timestamp], - symbol=symbol, - open=row[self._col_open], - high=row[self._col_high], - low=row[self._col_low], - close=row[self._col_close], - volume=row[self._col_volume], - ) - yield event + super().__init__( + df_or_file_path=file_path, + col_timestamp=col_timestamp, + col_symbol=col_symbol, + col_open=col_open, + col_high=col_high, + col_low=col_low, + col_close=col_close, + col_volume=col_volume, + ) diff --git a/alphaflow/data_feeds/polars_data_feed.py b/alphaflow/data_feeds/polars_data_feed.py new file mode 100644 index 0000000..cc53cea --- /dev/null +++ b/alphaflow/data_feeds/polars_data_feed.py @@ -0,0 +1,125 @@ +"""Polars data feed implementation.""" + +import logging +from collections.abc import Generator +from datetime import datetime +from pathlib import Path + +import polars as pl + +from alphaflow import DataFeed +from alphaflow.events.market_data_event import MarketDataEvent + +logger = logging.getLogger(__name__) + + +class PolarsDataFeed(DataFeed): + """Data feed that loads market data from Polars dataframes.""" + + def __init__( + self, + df_or_file_path: Path | str | pl.DataFrame | pl.LazyFrame, + *, + col_timestamp: str = "Date", + col_symbol: str = "Symbol", + col_open: str = "Open", + col_high: str = "High", + col_low: str = "Low", + col_close: str = "Close", + col_volume: str = "Volume", + ) -> None: + """Initialize the Polars data feed. + + Args: + df_or_file_path: Polars DataFrame or LazyFrame, or path to a CSV or Parquet file, containing market data. + col_timestamp: Name of the timestamp column. + col_symbol: Name of the symbol column. + col_open: Name of the open price column. + col_high: Name of the high price column. + col_low: Name of the low price column. + col_close: Name of the close price column. + col_volume: Name of the volume column. 
+ + """ + self.df_or_file_path = df_or_file_path + self._col_timestamp = col_timestamp + self._col_symbol = col_symbol + self._col_open = col_open + self._col_high = col_high + self._col_low = col_low + self._col_close = col_close + self._col_volume = col_volume + + def run( + self, + symbol: str, + start_timestamp: datetime | None, + end_timestamp: datetime | None, + ) -> Generator[MarketDataEvent, None, None]: + """Load and yield market data events from the Polars dataframe. + + Args: + symbol: The ticker symbol to load data for. + start_timestamp: Optional start time for filtering data. + end_timestamp: Optional end time for filtering data. + + Yields: + MarketDataEvent objects containing OHLCV data. + + Raises: + ValueError: If the file format is unsupported or required columns are missing from the data. + + """ + if isinstance(self.df_or_file_path, (str, Path)): + df_path = Path(self.df_or_file_path) if isinstance(self.df_or_file_path, str) else self.df_or_file_path + if df_path.suffix in {".parquet", ".parq"}: + df = pl.read_parquet(df_path) + df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime)) + elif df_path.suffix == ".csv": + df = pl.read_csv(df_path, try_parse_dates=True) + else: + raise ValueError(f"Unsupported file format: {df_path.suffix}") + elif isinstance(self.df_or_file_path, pl.LazyFrame): + df = self.df_or_file_path.collect() + else: + df = self.df_or_file_path + + required_cols = { + self._col_timestamp, + self._col_close, + self._col_high, + self._col_low, + self._col_open, + self._col_volume, + } + + missing_cols = required_cols.difference(df.columns) + if missing_cols: + raise ValueError(f"Missing columns: {missing_cols}") + + # Convert date column to datetime if needed (polars parses as date by default) + if df[self._col_timestamp].dtype == pl.Date: + df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime)) + + # Filter by symbol using polars + if self._col_symbol in df.columns: + df = df.filter(pl.col(self._col_symbol) == symbol) + + # 
Filter by timestamp bounds using polars + if start_timestamp: + df = df.filter(pl.col(self._col_timestamp) >= start_timestamp) + if end_timestamp: + df = df.filter(pl.col(self._col_timestamp) <= end_timestamp) + + # Convert to dicts once after all filtering + for row in df.sort(by=self._col_timestamp).iter_rows(named=True): + event = MarketDataEvent( + timestamp=row[self._col_timestamp], + symbol=symbol, + open=row[self._col_open], + high=row[self._col_high], + low=row[self._col_low], + close=row[self._col_close], + volume=row[self._col_volume], + ) + yield event diff --git a/alphaflow/tests/test_alphaflow.py b/alphaflow/tests/test_alphaflow.py index f6a934d..7657dd7 100644 --- a/alphaflow/tests/test_alphaflow.py +++ b/alphaflow/tests/test_alphaflow.py @@ -8,7 +8,7 @@ from alphaflow import AlphaFlow from alphaflow.analyzers import DefaultAnalyzer from alphaflow.brokers import SimpleBroker -from alphaflow.data_feeds import CSVDataFeed +from alphaflow.data_feeds import PolarsDataFeed from alphaflow.strategies import BuyAndHoldStrategy @@ -50,7 +50,7 @@ def test_alphaflow_add_equity() -> None: def test_alphaflow_set_data_feed() -> None: """Test setting the data feed.""" af = AlphaFlow() - data_feed = CSVDataFeed("alphaflow/tests/data/AAPL.csv") + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") af.set_data_feed(data_feed) @@ -174,7 +174,7 @@ def test_alphaflow_set_backtest_end_timestamp_string() -> None: def test_alphaflow_get_timestamps() -> None: """Test getting all timestamps from loaded data.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -192,7 +192,7 @@ def test_alphaflow_get_timestamps() -> None: def test_alphaflow_get_price() -> None: """Test getting price for a symbol at a timestamp.""" af = AlphaFlow() - 
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -207,7 +207,7 @@ def test_alphaflow_get_price() -> None: def test_alphaflow_get_price_raises_error_for_missing_data() -> None: """Test get_price raises error when no data exists after timestamp.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -222,7 +222,7 @@ def test_alphaflow_get_price_raises_error_for_missing_data() -> None: def test_alphaflow_on_missing_price_warn(caplog: pytest.LogCaptureFixture) -> None: """Test that on_missing_price='warn' logs a warning and returns 0.0.""" af = AlphaFlow(on_missing_price="warn") - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -239,7 +239,7 @@ def test_alphaflow_on_missing_price_warn(caplog: pytest.LogCaptureFixture) -> No def test_alphaflow_on_missing_price_ignore() -> None: """Test that on_missing_price='ignore' silently returns 0.0.""" af = AlphaFlow(on_missing_price="ignore") - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -269,7 +269,7 @@ def test_alphaflow_run_raises_error_without_data_feed() -> None: def test_alphaflow_run_raises_error_for_live_trading() -> None: """Test run raises error for live trading (not implemented).""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + 
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) @@ -280,7 +280,7 @@ def test_alphaflow_run_raises_error_for_live_trading() -> None: def test_alphaflow_complete_backtest_flow() -> None: """Test complete backtest flow with all components.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_benchmark("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) @@ -307,7 +307,7 @@ def test_alphaflow_complete_backtest_flow() -> None: def test_simple_backtest() -> None: """Test a simple buy-and-hold backtest with AAPL.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) af.set_broker(SimpleBroker()) diff --git a/alphaflow/tests/test_analyzer.py b/alphaflow/tests/test_analyzer.py index 4b1a13b..1c5a7e5 100644 --- a/alphaflow/tests/test_analyzer.py +++ b/alphaflow/tests/test_analyzer.py @@ -7,7 +7,7 @@ from alphaflow import AlphaFlow from alphaflow.analyzers import DefaultAnalyzer from alphaflow.brokers import SimpleBroker -from alphaflow.data_feeds import CSVDataFeed +from alphaflow.data_feeds import PolarsDataFeed from alphaflow.strategies import BuyAndHoldStrategy @@ -44,7 +44,7 @@ def test_default_analyzer_topic_subscriptions() -> None: def test_default_analyzer_with_backtest() -> None: """Test analyzer collects data during a backtest.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) af.set_broker(SimpleBroker()) @@ -69,7 +69,7 @@ def test_default_analyzer_generate_plot() -> None: 
plot_path = Path(tmpdir) / "test_plot.html" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) af.set_broker(SimpleBroker()) @@ -91,7 +91,7 @@ def test_default_analyzer_with_benchmark() -> None: plot_path = Path(tmpdir) / "benchmark_plot.html" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_benchmark("AAPL") # Use AAPL as its own benchmark af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) diff --git a/alphaflow/tests/test_broker.py b/alphaflow/tests/test_broker.py index 7f76390..42f05a4 100644 --- a/alphaflow/tests/test_broker.py +++ b/alphaflow/tests/test_broker.py @@ -7,7 +7,7 @@ from alphaflow import AlphaFlow from alphaflow.brokers import SimpleBroker from alphaflow.commission_models import FixedCommissionModel, PerShareCommissionModel -from alphaflow.data_feeds import CSVDataFeed +from alphaflow.data_feeds import PolarsDataFeed from alphaflow.enums import OrderType, Side, Topic from alphaflow.events import OrderEvent from alphaflow.events.market_data_event import MarketDataEvent @@ -32,7 +32,7 @@ def test_simple_broker_initialization_custom_margin() -> None: def test_broker_executes_valid_buy_order() -> None: """Test broker executes a valid buy order.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -75,7 +75,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_rejects_insufficient_buying_power() -> None: """Test broker rejects orders with insufficient buying power.""" af = AlphaFlow() - 
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10) # Very low cash - only $10 af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -113,7 +113,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_rejects_short_sell() -> None: """Test broker rejects short selling (selling without position).""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -150,7 +150,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_allows_valid_sell() -> None: """Test broker allows selling shares we own.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -191,7 +191,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_ignores_non_order_events() -> None: """Test broker ignores events that aren't orders.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -229,7 +229,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_with_margin() -> None: """Test broker allows larger positions with margin.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(1000) # Limited cash af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -270,7 +270,7 @@ def 
read_event(self, event): # type: ignore[no-untyped-def] def test_broker_with_slippage_and_commission() -> None: """Test broker with both slippage and commission models.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -332,7 +332,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_slippage_affects_buying_power_validation() -> None: """Test that slippage is considered when validating buying power.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_data_start_timestamp(datetime(1980, 12, 25)) af.run() @@ -377,7 +377,7 @@ def read_event(self, event): # type: ignore[no-untyped-def] def test_broker_commission_affects_buying_power_validation() -> None: """Test that commission is considered when validating buying power.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_data_start_timestamp(datetime(1980, 12, 25)) af.run() diff --git a/alphaflow/tests/test_data_feeds.py b/alphaflow/tests/test_csv_data_feed.py similarity index 90% rename from alphaflow/tests/test_data_feeds.py rename to alphaflow/tests/test_csv_data_feed.py index 7b2a952..15a5106 100644 --- a/alphaflow/tests/test_data_feeds.py +++ b/alphaflow/tests/test_csv_data_feed.py @@ -1,4 +1,4 @@ -"""Tests for data feeds.""" +"""Tests for CSV data feeds.""" from datetime import datetime @@ -164,3 +164,16 @@ def test_csv_data_feed_empty_range() -> None: ) assert len(events) == 0 + + +def test_deprecated_csv_data_feed() -> None: + """Test that using the deprecated CSVDataFeed issues a warning.""" + import warnings + + with 
warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + CSVDataFeed("alphaflow/tests/data/AAPL.csv") + + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "CSVDataFeed is deprecated" in str(w[-1].message) diff --git a/alphaflow/tests/test_polars_data_feed.py b/alphaflow/tests/test_polars_data_feed.py new file mode 100644 index 0000000..542ceec --- /dev/null +++ b/alphaflow/tests/test_polars_data_feed.py @@ -0,0 +1,257 @@ +"""Tests for Polars data feeds.""" + +from datetime import datetime +from pathlib import Path + +import polars as pl + +from alphaflow.data_feeds import PolarsDataFeed +from alphaflow.events import MarketDataEvent + + +def test_polars_data_feed_initialization() -> None: + """Test PolarsDataFeed initialization.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + assert isinstance(data_feed.df_or_file_path, str) + assert data_feed.df_or_file_path == "alphaflow/tests/data/AAPL.csv" + + +def test_polars_data_feed_run_yields_market_data_events() -> None: + """Test PolarsDataFeed yields MarketDataEvent objects.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + assert len(events) > 0 + assert all(isinstance(event, MarketDataEvent) for event in events) + + +def test_polars_data_feed_events_have_correct_symbol() -> None: + """Test all events have the requested symbol.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + data_feed.run( + symbol="TEST_SYMBOL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + assert all(event.symbol == "TEST_SYMBOL" for event in events) + + +def test_polars_data_feed_events_sorted_by_timestamp() -> None: + """Test events are yielded in chronological order.""" + data_feed = 
PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1981, 1, 31), + ) + ) + + timestamps = [event.timestamp for event in events] + assert timestamps == sorted(timestamps) + + +def test_polars_data_feed_respects_start_timestamp() -> None: + """Test data feed only yields events after start timestamp.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + start_timestamp = datetime(1981, 1, 1) + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=start_timestamp, + end_timestamp=datetime(1981, 1, 31), + ) + ) + + assert all(event.timestamp >= start_timestamp for event in events) + + +def test_polars_data_feed_respects_end_timestamp() -> None: + """Test data feed only yields events before end timestamp.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + end_timestamp = datetime(1981, 1, 15) + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=end_timestamp, + ) + ) + + assert all(event.timestamp <= end_timestamp for event in events) + + +def test_polars_data_feed_event_has_all_ohlcv_fields() -> None: + """Test MarketDataEvent has open, high, low, close, volume.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + # Check first event has all required fields + event = events[0] + assert hasattr(event, "open") + assert hasattr(event, "high") + assert hasattr(event, "low") + assert hasattr(event, "close") + assert hasattr(event, "volume") + assert hasattr(event, "timestamp") + assert hasattr(event, "symbol") + + +def test_polars_data_feed_prices_are_positive() -> None: + """Test all OHLC prices are positive.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + 
data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + for event in events: + assert event.open > 0 + assert event.high > 0 + assert event.low > 0 + assert event.close > 0 + + +def test_polars_data_feed_high_low_relationship() -> None: + """Test high >= low for all events.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1981, 1, 31), + ) + ) + + for event in events: + assert event.high >= event.low + + +def test_polars_data_feed_empty_range() -> None: + """Test data feed with date range that has no data.""" + data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv") + + # Use a date range before any data exists + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1970, 1, 1), + end_timestamp=datetime(1970, 1, 31), + ) + ) + + assert len(events) == 0 + + +def test_polars_data_feed_initialization_with_dataframe() -> None: + """Test PolarsDataFeed initialization with a Polars DataFrame.""" + df = pl.read_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + data_feed = PolarsDataFeed(df) + + assert isinstance(data_feed.df_or_file_path, pl.DataFrame) + + +def test_polars_data_feed_initialization_with_lazyframe() -> None: + """Test PolarsDataFeed initialization with a Polars LazyFrame.""" + lf = pl.scan_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + data_feed = PolarsDataFeed(lf) + + assert isinstance(data_feed.df_or_file_path, pl.LazyFrame) + + +def test_polars_data_feed_initialization_with_parquet(tmp_path: Path) -> None: + """Test PolarsDataFeed initialization with a parquet file path.""" + # Create a temporary parquet file from CSV data + df = pl.read_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + parquet_path = tmp_path / "AAPL.parquet" + df.write_parquet(parquet_path) + + data_feed = 
PolarsDataFeed(parquet_path) + + assert isinstance(data_feed.df_or_file_path, Path) + assert data_feed.df_or_file_path == parquet_path + + +def test_polars_data_feed_run_with_dataframe() -> None: + """Test PolarsDataFeed run with a Polars DataFrame yields correct events.""" + df = pl.read_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + data_feed = PolarsDataFeed(df) + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + assert len(events) > 0 + assert all(isinstance(event, MarketDataEvent) for event in events) + assert all(event.symbol == "AAPL" for event in events) + + +def test_polars_data_feed_run_with_lazyframe() -> None: + """Test PolarsDataFeed run with a Polars LazyFrame yields correct events.""" + lf = pl.scan_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + data_feed = PolarsDataFeed(lf) + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + assert len(events) > 0 + assert all(isinstance(event, MarketDataEvent) for event in events) + assert all(event.symbol == "AAPL" for event in events) + + +def test_polars_data_feed_run_with_parquet(tmp_path: Path) -> None: + """Test PolarsDataFeed run with a parquet file yields correct events.""" + # Create a temporary parquet file from CSV data + df = pl.read_csv("alphaflow/tests/data/AAPL.csv", try_parse_dates=True) + parquet_path = tmp_path / "AAPL.parquet" + df.write_parquet(parquet_path) + + data_feed = PolarsDataFeed(parquet_path) + + events = list( + data_feed.run( + symbol="AAPL", + start_timestamp=datetime(1980, 12, 25), + end_timestamp=datetime(1980, 12, 31), + ) + ) + + assert len(events) > 0 + assert all(isinstance(event, MarketDataEvent) for event in events) + assert all(event.symbol == "AAPL" for event in events) diff --git a/alphaflow/tests/test_portfolio.py b/alphaflow/tests/test_portfolio.py index 
310fffa..c2adb92 100644 --- a/alphaflow/tests/test_portfolio.py +++ b/alphaflow/tests/test_portfolio.py @@ -5,7 +5,7 @@ import pytest from alphaflow import AlphaFlow -from alphaflow.data_feeds import CSVDataFeed +from alphaflow.data_feeds import PolarsDataFeed from alphaflow.events import FillEvent from alphaflow.events.market_data_event import MarketDataEvent @@ -68,7 +68,7 @@ def test_portfolio_update_position() -> None: def test_portfolio_get_position_value() -> None: """Test calculating position value at a timestamp.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -89,7 +89,7 @@ def test_portfolio_get_position_value() -> None: def test_portfolio_get_positions_value() -> None: """Test calculating total positions value.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -108,7 +108,7 @@ def test_portfolio_get_positions_value() -> None: def test_portfolio_get_portfolio_value() -> None: """Test calculating total portfolio value (cash + positions).""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(5000) af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -128,7 +128,7 @@ def test_portfolio_get_portfolio_value() -> None: def test_portfolio_get_buying_power() -> None: """Test calculating buying power with margin.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_cash(10000) 
af.set_data_start_timestamp(datetime(1980, 12, 25)) @@ -164,7 +164,7 @@ def test_portfolio_get_benchmark_values_no_benchmark() -> None: def test_portfolio_get_benchmark_values_with_benchmark() -> None: """Test getting benchmark values when benchmark is set.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.set_benchmark("AAPL") af.set_cash(10000) diff --git a/alphaflow/tests/test_strategy.py b/alphaflow/tests/test_strategy.py index cd1cccb..3b5aabe 100644 --- a/alphaflow/tests/test_strategy.py +++ b/alphaflow/tests/test_strategy.py @@ -4,7 +4,7 @@ from alphaflow import AlphaFlow from alphaflow.brokers import SimpleBroker -from alphaflow.data_feeds import CSVDataFeed +from alphaflow.data_feeds import PolarsDataFeed from alphaflow.enums import Topic from alphaflow.strategies import BuyAndHoldStrategy @@ -27,7 +27,7 @@ def test_buy_and_hold_topic_subscriptions() -> None: def test_buy_and_hold_initial_purchase() -> None: """Test strategy makes initial purchase.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) af.set_broker(SimpleBroker()) @@ -46,7 +46,7 @@ def test_buy_and_hold_initial_purchase() -> None: def test_buy_and_hold_rebalancing() -> None: """Test strategy rebalances periodically.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") strategy = BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0) af.add_strategy(strategy) @@ -66,7 +66,7 @@ def test_buy_and_hold_rebalancing() -> None: def test_buy_and_hold_partial_allocation() -> None: """Test strategy with partial portfolio allocation.""" af = AlphaFlow() - 
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=0.5)) af.set_broker(SimpleBroker()) @@ -87,7 +87,7 @@ def test_buy_and_hold_partial_allocation() -> None: def test_buy_and_hold_filters_events_outside_backtest() -> None: """Test strategy ignores events outside backtest window.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) af.add_equity("AAPL") strategy = BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0) af.add_strategy(strategy) @@ -108,7 +108,7 @@ def test_buy_and_hold_filters_events_outside_backtest() -> None: def test_buy_and_hold_filters_wrong_symbol() -> None: """Test strategy ignores events for other symbols.""" af = AlphaFlow() - af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv")) + af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv")) # Add AAPL to universe but strategy only trades AAPL af.add_equity("AAPL") af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0)) diff --git a/docs/api/data_feeds.md b/docs/api/data_feeds.md index a5e8408..3e5416f 100644 --- a/docs/api/data_feeds.md +++ b/docs/api/data_feeds.md @@ -1,8 +1,8 @@ # Data Feeds -## CSVDataFeed +## PolarsDataFeed -::: alphaflow.data_feeds.CSVDataFeed +::: alphaflow.data_feeds.PolarsDataFeed options: show_root_heading: true heading_level: 3 diff --git a/docs/getting_started.md b/docs/getting_started.md index 5a9505d..7eb3901 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -86,7 +86,7 @@ AlphaFlow supports multiple data providers: - **PolygonDataFeed**: Real-time and historical data from [Polygon.io](https://polygon.io/) - **AlphaVantageFeed**: Free market data from [Alpha Vantage](https://www.alphavantage.co/) -- **CSVDataFeed**: Load data from local CSV files for 
testing or custom data sources +- **PolarsDataFeed**: Load data from Polars DataFrames or local CSV files for testing or custom data sources Example using Polygon: diff --git a/pyproject.toml b/pyproject.toml index e1765ab..0f11c2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "polars>=1.35.2", "plotly>=5.18.0", "python-dotenv>=1.0.0", + "typing-extensions>=4.15.0", ] [project.urls] diff --git a/uv.lock b/uv.lock index d7c08cd..00cfdae 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,7 @@ dependencies = [ { name = "plotly" }, { name = "polars" }, { name = "python-dotenv" }, + { name = "typing-extensions" }, ] [package.dev-dependencies] @@ -37,6 +38,7 @@ requires-dist = [ { name = "plotly", specifier = ">=5.18.0" }, { name = "polars", specifier = ">=1.35.2" }, { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "typing-extensions", specifier = ">=4.15.0" }, ] [package.metadata.requires-dev]