|
| 1 | +from benchopt import safe_import_context, BaseSolver |
| 2 | + |
| 3 | +with safe_import_context() as import_ctx: |
| 4 | + from rosecdl.rosecdl import RoseCDL |
| 5 | + import torch |
| 6 | + |
| 7 | + |
| 8 | +class Solver(BaseSolver): |
| 9 | + name = "RoseCDL" |
| 10 | + |
| 11 | + install_cmd = "conda" |
| 12 | + requirements = ["pip:rosecdl"] |
| 13 | + |
| 14 | + parameters = { |
| 15 | + "n_components": [1], |
| 16 | + "n_channels": [1], |
| 17 | + "kernel_size": [64], |
| 18 | + "lmbd": [0.8], |
| 19 | + "scale_lmbd": [False], |
| 20 | + "epochs": [5, 50], |
| 21 | + "max_batch": [None], |
| 22 | + "mini_batch_size": [600], |
| 23 | + "sample_window": [10_000], |
| 24 | + "optimizer": ["adam"], |
| 25 | + "n_iterations": [10, 90], |
| 26 | + "window": [False], |
| 27 | + "outliers_kwargs": [ |
| 28 | + { |
| 29 | + "method": "mad", |
| 30 | + "alpha": 3.5, |
| 31 | + "moving_average": None, |
| 32 | + "union_channels": True, |
| 33 | + "opening_window": True, |
| 34 | + }, |
| 35 | + ], |
| 36 | + } |
| 37 | + |
| 38 | + sampling_strategy = "run_once" |
| 39 | + |
| 40 | + def set_objective(self, X_train, y_test, X_test): |
| 41 | + self.device = torch.device( |
| 42 | + "cuda" if torch.cuda.is_available() else "cpu" |
| 43 | + ) |
| 44 | + |
| 45 | + # We receive data in shape (n_samples, n_features) |
| 46 | + # We want to reshape it to (n_recordings, n_features, n_samples) |
| 47 | + X_train = X_train.reshape(1, X_train.shape[1], X_train.shape[0]) |
| 48 | + X_test = X_test.reshape(1, X_test.shape[1], X_test.shape[0]) |
| 49 | + self.y_test = y_test |
| 50 | + |
| 51 | + self.X_train = torch.tensor( |
| 52 | + X_train, dtype=torch.float32, device=self.device) |
| 53 | + self.X_test = X_test |
| 54 | + |
| 55 | + self.clf = RoseCDL( |
| 56 | + n_components=self.n_components, |
| 57 | + n_channels=self.n_channels, |
| 58 | + kernel_size=self.kernel_size, |
| 59 | + lmbd=self.lmbd, |
| 60 | + scale_lmbd=self.scale_lmbd, |
| 61 | + epochs=self.epochs, |
| 62 | + max_batch=self.max_batch, |
| 63 | + mini_batch_size=self.mini_batch_size, |
| 64 | + sample_window=self.sample_window, |
| 65 | + optimizer=self.optimizer, |
| 66 | + n_iterations=self.n_iterations, |
| 67 | + window=self.window, |
| 68 | + device=self.device, |
| 69 | + outliers_kwargs=self.outliers_kwargs, |
| 70 | + ) |
| 71 | + |
| 72 | + def run(self, _): |
| 73 | + self.clf.fit(self.X_train) |
| 74 | + self.y_pred = self.clf.get_outlier_mask(self.X_test) |
| 75 | + |
| 76 | + def get_result(self): |
| 77 | + return dict(y_hat=self.y_pred) |
0 commit comments