Skip to content

Commit 2cdf79a

Browse files
committed
ADD RoseCDL solver
1 parent 3a0ab4c commit 2cdf79a

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

solvers/rosecdl.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from benchopt import safe_import_context, BaseSolver
2+
3+
with safe_import_context() as import_ctx:
4+
from rosecdl.rosecdl import RoseCDL
5+
import torch
6+
7+
8+
class Solver(BaseSolver):
9+
name = "RoseCDL"
10+
11+
install_cmd = "conda"
12+
requirements = ["pip:rosecdl"]
13+
14+
parameters = {
15+
"n_components": [1],
16+
"n_channels": [1],
17+
"kernel_size": [64],
18+
"lmbd": [0.8],
19+
"scale_lmbd": [False],
20+
"epochs": [5, 50],
21+
"max_batch": [None],
22+
"mini_batch_size": [600],
23+
"sample_window": [10_000],
24+
"optimizer": ["adam"],
25+
"n_iterations": [10, 90],
26+
"window": [False],
27+
"outliers_kwargs": [
28+
{
29+
"method": "mad",
30+
"alpha": 3.5,
31+
"moving_average": None,
32+
"union_channels": True,
33+
"opening_window": True,
34+
},
35+
],
36+
}
37+
38+
sampling_strategy = "run_once"
39+
40+
def set_objective(self, X_train, y_test, X_test):
41+
self.device = torch.device(
42+
"cuda" if torch.cuda.is_available() else "cpu"
43+
)
44+
45+
# We receive data in shape (n_samples, n_features)
46+
# We want to reshape it to (n_recordings, n_features, n_samples)
47+
X_train = X_train.reshape(1, X_train.shape[1], X_train.shape[0])
48+
X_test = X_test.reshape(1, X_test.shape[1], X_test.shape[0])
49+
self.y_test = y_test
50+
51+
self.X_train = torch.tensor(
52+
X_train, dtype=torch.float32, device=self.device)
53+
self.X_test = X_test
54+
55+
self.clf = RoseCDL(
56+
n_components=self.n_components,
57+
n_channels=self.n_channels,
58+
kernel_size=self.kernel_size,
59+
lmbd=self.lmbd,
60+
scale_lmbd=self.scale_lmbd,
61+
epochs=self.epochs,
62+
max_batch=self.max_batch,
63+
mini_batch_size=self.mini_batch_size,
64+
sample_window=self.sample_window,
65+
optimizer=self.optimizer,
66+
n_iterations=self.n_iterations,
67+
window=self.window,
68+
device=self.device,
69+
outliers_kwargs=self.outliers_kwargs,
70+
)
71+
72+
def run(self, _):
73+
self.clf.fit(self.X_train)
74+
self.y_pred = self.clf.get_outlier_mask(self.X_test)
75+
76+
def get_result(self):
77+
return dict(y_hat=self.y_pred)

0 commit comments

Comments
 (0)