Skip to content

Commit b92afd3

Browse files
Add a PolarsDataFeed and deprecate the CSVDataFeed
1 parent e098489 commit b92afd3

16 files changed

+463
-120
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111
- `on_missing_price` config option - control behavior when `get_price()` cannot find data: `"raise"` (default), `"warn"`, or `"ignore"`
12+
- `PolarsDataFeed` - Load historical data directly from Polars DataFrames
1213

1314
### Changed
1415
- Replaced `Make` with `just` for development commands
1516

17+
### Deprecated
18+
- `CSVDataFeed` - Use `PolarsDataFeed` instead for loading data from CSV files
19+
1620
## [0.2.0] - 2025-11-11
1721

1822
### Added

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@
124124
125125
from alphaflow import AlphaFlow
126126
from alphaflow.brokers import SimpleBroker
127-
from alphaflow.data_feeds import CSVDataFeed
127+
from alphaflow.data_feeds import PolarsDataFeed
128128
from alphaflow.strategies import BuyAndHoldStrategy
129129
130130
# 1. Initialize AlphaFlow
@@ -135,8 +135,8 @@
135135
136136
# 2. Create DataFeed (e.g., CSV-based daily bars)
137137
flow.set_data_feed(
138-
CSVDataFeed(
139-
file_path="historical_data.csv",
138+
PolarsDataFeed(
139+
df_or_file_path="historical_data.csv",
140140
)
141141
)
142142

alphaflow/data_feeds/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from alphaflow.data_feeds.alpha_vantage_data_feed import AlphaVantageFeed
44
from alphaflow.data_feeds.csv_data_feed import CSVDataFeed
55
from alphaflow.data_feeds.fmp_data_feed import FMPDataFeed
6+
from alphaflow.data_feeds.polars_data_feed import PolarsDataFeed
67
from alphaflow.data_feeds.polygon_data_feed import PolygonDataFeed
78

8-
__all__ = ["AlphaVantageFeed", "CSVDataFeed", "FMPDataFeed", "PolygonDataFeed"]
9+
__all__ = ["AlphaVantageFeed", "CSVDataFeed", "FMPDataFeed", "PolarsDataFeed", "PolygonDataFeed"]
Lines changed: 16 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
"""CSV file data feed implementation."""
22

33
import logging
4-
from collections.abc import Generator
5-
from datetime import datetime
64
from pathlib import Path
75

8-
import polars as pl
6+
from typing_extensions import deprecated
97

10-
from alphaflow import DataFeed
11-
from alphaflow.events.market_data_event import MarketDataEvent
8+
from alphaflow.data_feeds.polars_data_feed import PolarsDataFeed
129

1310
logger = logging.getLogger(__name__)
1411

1512

16-
class CSVDataFeed(DataFeed):
13+
@deprecated("CSVDataFeed is deprecated and will be removed in a future release. Please use PolarsDataFeed instead.")
14+
class CSVDataFeed(PolarsDataFeed):
1715
"""Data feed that loads market data from CSV files."""
1816

1917
def __init__(
@@ -30,6 +28,8 @@ def __init__(
3028
) -> None:
3129
"""Initialize the CSV data feed.
3230
31+
**Deprecated**: Use PolarsDataFeed instead.
32+
3333
Args:
3434
file_path: Path to the CSV file containing market data.
3535
col_timestamp: Name of the timestamp column.
@@ -42,73 +42,13 @@ def __init__(
4242
4343
"""
4444
self.file_path = Path(file_path) if isinstance(file_path, str) else file_path
45-
self._col_timestamp = col_timestamp
46-
self._col_symbol = col_symbol
47-
self._col_open = col_open
48-
self._col_high = col_high
49-
self._col_low = col_low
50-
self._col_close = col_close
51-
self._col_volume = col_volume
52-
53-
def run(
54-
self,
55-
symbol: str,
56-
start_timestamp: datetime | None,
57-
end_timestamp: datetime | None,
58-
) -> Generator[MarketDataEvent, None, None]:
59-
"""Load and yield market data events from the CSV file.
60-
61-
Args:
62-
symbol: The ticker symbol to load data for.
63-
start_timestamp: Optional start time for filtering data.
64-
end_timestamp: Optional end time for filtering data.
65-
66-
Yields:
67-
MarketDataEvent objects containing OHLCV data.
68-
69-
Raises:
70-
ValueError: If required columns are missing from the CSV.
71-
72-
"""
73-
logger.debug("Opening CSV file...")
74-
df = pl.read_csv(self.file_path, try_parse_dates=True)
75-
76-
required_cols = {
77-
self._col_timestamp,
78-
self._col_close,
79-
self._col_high,
80-
self._col_low,
81-
self._col_open,
82-
self._col_volume,
83-
}
84-
85-
missing_cols = required_cols.difference(df.columns)
86-
if missing_cols:
87-
raise ValueError(f"Missing columns: {missing_cols}")
88-
89-
# Convert date column to datetime if needed (polars parses as date by default)
90-
if df[self._col_timestamp].dtype == pl.Date:
91-
df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime))
92-
93-
# Filter by symbol using polars
94-
if self._col_symbol in df.columns:
95-
df = df.filter(pl.col(self._col_symbol) == symbol)
96-
97-
# Filter by timestamp bounds using polars
98-
if start_timestamp:
99-
df = df.filter(pl.col(self._col_timestamp) >= start_timestamp)
100-
if end_timestamp:
101-
df = df.filter(pl.col(self._col_timestamp) <= end_timestamp)
102-
103-
# Convert to dicts once after all filtering
104-
for row in df.iter_rows(named=True):
105-
event = MarketDataEvent(
106-
timestamp=row[self._col_timestamp],
107-
symbol=symbol,
108-
open=row[self._col_open],
109-
high=row[self._col_high],
110-
low=row[self._col_low],
111-
close=row[self._col_close],
112-
volume=row[self._col_volume],
113-
)
114-
yield event
45+
super().__init__(
46+
df_or_file_path=file_path,
47+
col_timestamp=col_timestamp,
48+
col_symbol=col_symbol,
49+
col_open=col_open,
50+
col_high=col_high,
51+
col_low=col_low,
52+
col_close=col_close,
53+
col_volume=col_volume,
54+
)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Polars data feed implementation."""
2+
3+
import logging
4+
from collections.abc import Generator
5+
from datetime import datetime
6+
from pathlib import Path
7+
8+
import polars as pl
9+
10+
from alphaflow import DataFeed
11+
from alphaflow.events.market_data_event import MarketDataEvent
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class PolarsDataFeed(DataFeed):
17+
"""Data feed that loads market data from Polars dataframes."""
18+
19+
def __init__(
20+
self,
21+
df_or_file_path: Path | str | pl.DataFrame | pl.LazyFrame,
22+
*,
23+
col_timestamp: str = "Date",
24+
col_symbol: str = "Symbol",
25+
col_open: str = "Open",
26+
col_high: str = "High",
27+
col_low: str = "Low",
28+
col_close: str = "Close",
29+
col_volume: str = "Volume",
30+
) -> None:
31+
"""Initialize the Polars data feed.
32+
33+
Args:
34+
df_or_file_path: Polars dataframe or path to the Polars dataframe containing market data.
35+
col_timestamp: Name of the timestamp column.
36+
col_symbol: Name of the symbol column.
37+
col_open: Name of the open price column.
38+
col_high: Name of the high price column.
39+
col_low: Name of the low price column.
40+
col_close: Name of the close price column.
41+
col_volume: Name of the volume column.
42+
43+
"""
44+
self.df_or_file_path = df_or_file_path
45+
self._col_timestamp = col_timestamp
46+
self._col_symbol = col_symbol
47+
self._col_open = col_open
48+
self._col_high = col_high
49+
self._col_low = col_low
50+
self._col_close = col_close
51+
self._col_volume = col_volume
52+
53+
def run(
54+
self,
55+
symbol: str,
56+
start_timestamp: datetime | None,
57+
end_timestamp: datetime | None,
58+
) -> Generator[MarketDataEvent, None, None]:
59+
"""Load and yield market data events from the Polars dataframe.
60+
61+
Args:
62+
symbol: The ticker symbol to load data for.
63+
start_timestamp: Optional start time for filtering data.
64+
end_timestamp: Optional end time for filtering data.
65+
66+
Yields:
67+
MarketDataEvent objects containing OHLCV data.
68+
69+
Raises:
70+
ValueError: If required columns are missing from the Polars dataframe.
71+
72+
"""
73+
if isinstance(self.df_or_file_path, (str, Path)):
74+
df_path = Path(self.df_or_file_path) if isinstance(self.df_or_file_path, str) else self.df_or_file_path
75+
if df_path.suffix in {".parquet", ".parq"}:
76+
df = pl.read_parquet(df_path)
77+
df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime))
78+
elif df_path.suffix == ".csv":
79+
df = pl.read_csv(df_path, try_parse_dates=True)
80+
else:
81+
raise ValueError(f"Unsupported file format: {df_path.suffix}")
82+
elif isinstance(self.df_or_file_path, pl.LazyFrame):
83+
df = self.df_or_file_path.collect()
84+
else:
85+
df = self.df_or_file_path
86+
87+
required_cols = {
88+
self._col_timestamp,
89+
self._col_close,
90+
self._col_high,
91+
self._col_low,
92+
self._col_open,
93+
self._col_volume,
94+
}
95+
96+
missing_cols = required_cols.difference(df.columns)
97+
if missing_cols:
98+
raise ValueError(f"Missing columns: {missing_cols}")
99+
100+
# Convert date column to datetime if needed (polars parses as date by default)
101+
if df[self._col_timestamp].dtype == pl.Date:
102+
df = df.with_columns(pl.col(self._col_timestamp).cast(pl.Datetime))
103+
104+
# Filter by symbol using polars
105+
if self._col_symbol in df.columns:
106+
df = df.filter(pl.col(self._col_symbol) == symbol)
107+
108+
# Filter by timestamp bounds using polars
109+
if start_timestamp:
110+
df = df.filter(pl.col(self._col_timestamp) >= start_timestamp)
111+
if end_timestamp:
112+
df = df.filter(pl.col(self._col_timestamp) <= end_timestamp)
113+
114+
# Convert to dicts once after all filtering
115+
for row in df.sort(by=self._col_timestamp).iter_rows(named=True):
116+
event = MarketDataEvent(
117+
timestamp=row[self._col_timestamp],
118+
symbol=symbol,
119+
open=row[self._col_open],
120+
high=row[self._col_high],
121+
low=row[self._col_low],
122+
close=row[self._col_close],
123+
volume=row[self._col_volume],
124+
)
125+
yield event

alphaflow/tests/test_alphaflow.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from alphaflow import AlphaFlow
99
from alphaflow.analyzers import DefaultAnalyzer
1010
from alphaflow.brokers import SimpleBroker
11-
from alphaflow.data_feeds import CSVDataFeed
11+
from alphaflow.data_feeds import PolarsDataFeed
1212
from alphaflow.strategies import BuyAndHoldStrategy
1313

1414

@@ -50,7 +50,7 @@ def test_alphaflow_add_equity() -> None:
5050
def test_alphaflow_set_data_feed() -> None:
5151
"""Test setting the data feed."""
5252
af = AlphaFlow()
53-
data_feed = CSVDataFeed("alphaflow/tests/data/AAPL.csv")
53+
data_feed = PolarsDataFeed("alphaflow/tests/data/AAPL.csv")
5454

5555
af.set_data_feed(data_feed)
5656

@@ -174,7 +174,7 @@ def test_alphaflow_set_backtest_end_timestamp_string() -> None:
174174
def test_alphaflow_get_timestamps() -> None:
175175
"""Test getting all timestamps from loaded data."""
176176
af = AlphaFlow()
177-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
177+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
178178
af.add_equity("AAPL")
179179
af.set_cash(10000)
180180
af.set_data_start_timestamp(datetime(1980, 12, 25))
@@ -192,7 +192,7 @@ def test_alphaflow_get_timestamps() -> None:
192192
def test_alphaflow_get_price() -> None:
193193
"""Test getting price for a symbol at a timestamp."""
194194
af = AlphaFlow()
195-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
195+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
196196
af.add_equity("AAPL")
197197
af.set_cash(10000)
198198
af.set_data_start_timestamp(datetime(1980, 12, 25))
@@ -207,7 +207,7 @@ def test_alphaflow_get_price() -> None:
207207
def test_alphaflow_get_price_raises_error_for_missing_data() -> None:
208208
"""Test get_price raises error when no data exists after timestamp."""
209209
af = AlphaFlow()
210-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
210+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
211211
af.add_equity("AAPL")
212212
af.set_cash(10000)
213213
af.set_data_start_timestamp(datetime(1980, 12, 25))
@@ -222,7 +222,7 @@ def test_alphaflow_get_price_raises_error_for_missing_data() -> None:
222222
def test_alphaflow_on_missing_price_warn(caplog: pytest.LogCaptureFixture) -> None:
223223
"""Test that on_missing_price='warn' logs a warning and returns 0.0."""
224224
af = AlphaFlow(on_missing_price="warn")
225-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
225+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
226226
af.add_equity("AAPL")
227227
af.set_cash(10000)
228228
af.set_data_start_timestamp(datetime(1980, 12, 25))
@@ -239,7 +239,7 @@ def test_alphaflow_on_missing_price_warn(caplog: pytest.LogCaptureFixture) -> No
239239
def test_alphaflow_on_missing_price_ignore() -> None:
240240
"""Test that on_missing_price='ignore' silently returns 0.0."""
241241
af = AlphaFlow(on_missing_price="ignore")
242-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
242+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
243243
af.add_equity("AAPL")
244244
af.set_cash(10000)
245245
af.set_data_start_timestamp(datetime(1980, 12, 25))
@@ -269,7 +269,7 @@ def test_alphaflow_run_raises_error_without_data_feed() -> None:
269269
def test_alphaflow_run_raises_error_for_live_trading() -> None:
270270
"""Test run raises error for live trading (not implemented)."""
271271
af = AlphaFlow()
272-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
272+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
273273
af.add_equity("AAPL")
274274
af.set_cash(10000)
275275

@@ -280,7 +280,7 @@ def test_alphaflow_run_raises_error_for_live_trading() -> None:
280280
def test_alphaflow_complete_backtest_flow() -> None:
281281
"""Test complete backtest flow with all components."""
282282
af = AlphaFlow()
283-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
283+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
284284
af.add_equity("AAPL")
285285
af.set_benchmark("AAPL")
286286
af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0))
@@ -307,7 +307,7 @@ def test_alphaflow_complete_backtest_flow() -> None:
307307
def test_simple_backtest() -> None:
308308
"""Test a simple buy-and-hold backtest with AAPL."""
309309
af = AlphaFlow()
310-
af.set_data_feed(CSVDataFeed("alphaflow/tests/data/AAPL.csv"))
310+
af.set_data_feed(PolarsDataFeed("alphaflow/tests/data/AAPL.csv"))
311311
af.add_equity("AAPL")
312312
af.add_strategy(BuyAndHoldStrategy(symbol="AAPL", target_weight=1.0))
313313
af.set_broker(SimpleBroker())

0 commit comments

Comments
 (0)