diff --git a/src/roiextract/analytic.py b/src/roiextract/analytic.py index 13f87e2..3f11aa1 100644 --- a/src/roiextract/analytic.py +++ b/src/roiextract/analytic.py @@ -4,6 +4,28 @@ from scipy.linalg import eig +def _ged_matrices_ratio(leadfield, mask, source_cov=None): + L_in = leadfield[:, mask] + + A = L_in @ L_in.T + B = leadfield @ leadfield.T + + return A, B + + +def _ctf_optimize_ratio(leadfield, mask, reg, source_cov=None): + # Get the GED matrices and apply regularization + A, B = _ged_matrices_ratio(leadfield, mask, source_cov=None) + A_reg = A + reg * np.trace(A) * np.eye(*A.shape) / A.shape[0] + B_reg = B + reg * np.trace(B) * np.eye(*B.shape) / B.shape[0] + + # Get the eigenvector that corresponds to the smallest eigenvalue + [eigvals, eigvecs] = eig(A_reg, B_reg) + w = eigvecs[:, eigvals.argmax()] + + return w + + def ctf_optimize_ratio_similarity( leadfield, template, mask, lambda_, regA=0.001, regB=0.001 ): diff --git a/src/roiextract/optimize.py b/src/roiextract/optimize.py index 55958fa..b77d8bc 100644 --- a/src/roiextract/optimize.py +++ b/src/roiextract/optimize.py @@ -2,7 +2,10 @@ from functools import partial -from .analytic import ctf_optimize_ratio_similarity +from .analytic import ( + _ctf_optimize_ratio, + ctf_optimize_ratio_similarity +) from .filter import SpatialFilter from .numerical import ctf_optimize_ratio_homogeneity from .quantify import ctf_quantify @@ -210,3 +213,18 @@ def ctf_optimize_label( ch_names=ch_names, name=name, ) + + +def ctf_optimize_ratio(fwd, label, reg=0.001): + leadfield = prepare_leadfield(fwd) + mask = prepare_label_mask(label, fwd) + + w = _ctf_optimize_ratio(leadfield, mask, reg) + return SpatialFilter( + w=w, + method="ctf_optimize_ratio", + method_params=dict(reg=reg), + ch_names=fwd.ch_names, + name=label.name, + ) + diff --git a/src/roiextract/prepare.py b/src/roiextract/prepare.py new file mode 100644 index 0000000..78a2b00 --- /dev/null +++ b/src/roiextract/prepare.py @@ -0,0 +1,45 @@ +import mne +import numpy as np + +from mne.label import label_sign_flip + +from ctfopt.roiextract.filter import SpatialFilter +from ctfopt.roiextract.utils import get_label_mask, _check_input + + +def prepare_filter(sf): + if not isinstance(sf, SpatialFilter): + return sf + + return sf.w + + +def prepare_leadfield(fwd): + if not isinstance(fwd, mne.Forward): + return fwd + + # NOTE: fixed orientations only + return fwd["sol"]["data"] + + +def prepare_label_mask(label, fwd=None): + if not isinstance(label, mne.Label | mne.BiHemiLabel): + return label + + assert fwd is not None and isinstance(fwd, mne.Forward) + return get_label_mask(label, fwd["src"]) + + +def prepare_template(template, label=None, fwd=None): + if isinstance(template, str): + _check_input("template", template, ["mean_flip", "mean"]) + assert fwd is not None and isinstance(fwd, mne.Forward) + assert label is not None and isinstance(label, mne.Label | mne.BiHemiLabel) + signflip = label_sign_flip(label, fwd["src"])[np.newaxis, :] + + if template == "mean_flip": + return signflip + if template == "mean": + return np.ones((1, signflip.size)) + + return template