Skip to content

Commit 97de58e

Browse files
rterbushmementum
authored andcommitted
Update example system to use new BT Indicators. (mementum#273)
1 parent f0b0998 commit 97de58e

File tree

1 file changed

+12
-72
lines changed

1 file changed

+12
-72
lines changed

contrib/samples/pair-trading/pair-trading.py

+12-72
Original file line numberDiff line numberDiff line change
@@ -15,68 +15,6 @@
1515
import backtrader as bt
1616
import backtrader.feeds as btfeeds
1717
import backtrader.indicators as btind
18-
import statsmodels.api as sm
19-
import pandas as pd
20-
21-
class OLS_Transformation(btind.PeriodN):
22-
lines = (('slope'),('intercept'),('spread'),('spread_mean'),('spread_std'),('zscore'),)
23-
params = (('period', 10),)
24-
25-
def __init__(self):
26-
self.addminperiod(self.p.period)
27-
28-
def next(self):
29-
p0 = pd.Series(self.data0.get(size=self.p.period))
30-
p1 = sm.add_constant(pd.Series(self.data1.get(size=self.p.period)),prepend=True)
31-
slope, intercept = sm.OLS(p0,p1).fit().params
32-
self.lines.slope[0] = slope
33-
self.lines.intercept[0] = intercept
34-
self.lines.spread[0] = self.data0.close[0] - (slope * self.data1.close[0] + intercept)
35-
self.lines.spread_mean[0] = pd.Series(self.lines.spread.get(size=self.p.period)).mean()
36-
self.lines.spread_std[0] = pd.Series(self.lines.spread.get(size=self.p.period)).std()
37-
self.lines.zscore[0] = (self.lines.spread[0] - self.lines.spread_mean[0])/self.lines.spread_std[0]
38-
39-
40-
class OLS_Beta(bt.indicators.PeriodN):
41-
_mindatas = 2 # ensure at least 2 data feeds are passed
42-
lines = (('beta'),)
43-
params = (('period', 10),)
44-
45-
def next(self):
46-
y, x = (pd.Series(d.get(size=self.p.period)) for d in (self.data0, self.data1))
47-
r_beta = pd.ols(y=y, x=x, window_type='full_sample')
48-
self.lines.beta[0] = r_beta.beta['x']
49-
50-
class Spread(bt.indicators.PeriodN):
51-
_mindatas = 2 # ensure at least 2 data feeds are passed
52-
lines = (('spread'),)
53-
params = (('period', 10),)
54-
55-
def next(self):
56-
y, x = (pd.Series(d.get(size=self.p.period)) for d in (self.data0, self.data1))
57-
r_beta = pd.ols(y=y, x=x, window_type='full_sample')
58-
self.lines.spread[0] = self.data1[0] - r_beta.beta['x'] * self.data0[0]
59-
60-
class ZScore(bt.indicators.PeriodN):
61-
_mindatas = 2 # ensure at least 2 data feeds are passed
62-
lines = (('zscore'),('upper'),('lower'),('up_medium'),('low_medium'),)
63-
params = (('period', 10),('upper',2),('lower',-2),('up_medium',0.5),('low_medium',-0.5),)
64-
65-
def __init__(self):
66-
self.spread = Spread(self.data0, self.data1, period=self.p.period, plot=False)
67-
self.spread_mean = btind.MovAv.SMA(self.spread, period=self.p.period, plot=False)
68-
self.spread_std = btind.StandardDeviation(self.spread, period=self.p.period, plot=False)
69-
70-
def next(self):
71-
# Step 1: Construct ZScore
72-
if self.spread_std[0]>0:
73-
self.lines.zscore[0] = (self.spread[0] - self.spread_mean[0])/self.spread_std[0]
74-
if self.spread_std[0]==0:
75-
self.lines.zscore[0] = 0
76-
self.lines.upper[0]=self.p.upper
77-
self.lines.lower[0] = self.p.lower
78-
self.lines.up_medium[0]=self.p.up_medium
79-
self.lines.low_medium[0] = self.p.low_medium
8018

8119

8220
class PairTradingStrategy(bt.Strategy):
@@ -86,12 +24,12 @@ class PairTradingStrategy(bt.Strategy):
8624
qty1=0,
8725
qty2=0,
8826
printout=True,
89-
upper = 2.1,
90-
lower = -2.1,
91-
up_medium = 0.5,
92-
low_medium = -0.5,
93-
status = 0,
94-
portfolio_value = 10000,
27+
upper=2.1,
28+
lower=-2.1,
29+
up_medium=0.5,
30+
low_medium=-0.5,
31+
status=0,
32+
portfolio_value=10000,
9533
)
9634

9735
def log(self, txt, dt=None):
@@ -132,11 +70,14 @@ def __init__(self):
13270
self.portfolio_value = self.p.portfolio_value
13371

13472
# Signals performed with PD.OLS :
135-
self.zscore = ZScore(self.data0,self.data1, period=self.p.period, upper=self.p.upper, lower=self.p.lower, up_medium=self.p.up_medium, low_medium=self.p.low_medium)
73+
self.transform = btind.OLS_TransformationN(self.data0, self.data1,
74+
period=self.p.period)
75+
self.zscore = self.transform.zscore
13676

13777
# Checking signals built with StatsModel.API :
138-
# self.ols_transfo = OLS_Transformation(self.data0, self.data1, period=self.p.period, plot=True)
139-
78+
# self.ols_transfo = btind.OLS_Transformation(self.data0, self.data1,
79+
# period=self.p.period,
80+
# plot=True)
14081

14182
def next(self):
14283

@@ -156,7 +97,6 @@ def next(self):
15697
print('status is', self.status)
15798
print('zscore is', self.zscore[0])
15899

159-
160100
# Step 2: Check conditions for SHORT & place the order
161101
# Checking the condition for SHORT
162102
if (self.zscore[0] > self.upper_limit) and (self.status != 1):

0 commit comments

Comments
 (0)