Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/roiextract/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@
from scipy.linalg import eig


def _ged_matrices_ratio(leadfield, mask, source_cov=None):
L_in = leadfield[:, mask]

A = L_in @ L_in.T
B = leadfield @ leadfield.T

return A, B


def _ctf_optimize_ratio(leadfield, mask, reg, source_cov=None):
# Get the GED matrices and apply regularization
A, B = _ged_matrices_ratio(leadfield, mask, source_cov=None)
A_reg = A + reg * np.trace(A) * np.eye(*A.shape) / A.shape[0]
B_reg = B + reg * np.trace(B) * np.eye(*B.shape) / B.shape[0]

# Get the eigenvector that corresponds to the smallest eigenvalue
[eigvals, eigvecs] = eig(A_reg, B_reg)
w = eigvecs[:, eigvals.argmax()]

return w


def ctf_optimize_ratio_similarity(
leadfield, template, mask, lambda_, regA=0.001, regB=0.001
):
Expand Down
20 changes: 19 additions & 1 deletion src/roiextract/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from functools import partial

from .analytic import ctf_optimize_ratio_similarity
from .analytic import (
_ctf_optimize_ratio,
ctf_optimize_ratio_similarity
)
from .filter import SpatialFilter
from .numerical import ctf_optimize_ratio_homogeneity
from .quantify import ctf_quantify
Expand Down Expand Up @@ -210,3 +213,18 @@ def ctf_optimize_label(
ch_names=ch_names,
name=name,
)


def ctf_optimize_ratio(fwd, label, reg=0.001):
leadfield = prepare_leadfield(fwd)
mask = prepare_label_mask(label, fwd)

w = _ctf_optimize_ratio(leadfield, mask, reg)
return SpatialFilter(
w=w,
method="ctf_optimize_ratio",
method_params=dict(reg=reg),
ch_names=fwd.ch_names,
name=label.name,
)

45 changes: 45 additions & 0 deletions src/roiextract/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import mne
import numpy as np

from mne.label import label_sign_flip

from ctfopt.roiextract.filter import SpatialFilter
from ctfopt.roiextract.utils import get_label_mask, _check_input


def prepare_filter(sf):
if not isinstance(sf, SpatialFilter):
return sf

return sf.w


def prepare_leadfield(fwd):
if not isinstance(fwd, mne.Forward):
return fwd

# NOTE: fixed orientations only
return fwd["sol"]["data"]


def prepare_label_mask(label, fwd=None):
if not isinstance(label, mne.Label | mne.BiHemiLabel):
return label

assert fwd is not None and isinstance(fwd, mne.Forward)
return get_label_mask(label, fwd["src"])


def prepare_template(template, label=None, fwd=None):
if isinstance(template, str):
_check_input("template", template, ["mean_flip", "mean"])
assert fwd is not None and isinstance(fwd, mne.Forward)
assert label is not None and isinstance(label, mne.Label | mne.BiHemiLabel)
signflip = label_sign_flip(label, fwd["src"])[np.newaxis, :]

if template == "mean_flip":
return signflip
if template == "mean":
return np.ones((1, signflip.size))

return template
Loading