Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/roiextract/_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from roiextract.optimize import ctf_optimize_label


def sample_criterion(
lmbd, fwd, label, template, mode, reg, dist_fun, return_filter=False
):
sf, props = ctf_optimize_label(
fwd,
label,
template,
lmbd,
mode,
reg=reg,
initial="auto",
quantify=True,
)

if return_filter:
return sf

x_key = "hom" if mode == "homogeneity" else "sim"
dist = dist_fun(props[x_key], props["rat"])
props_parts = [f"{k}: {v:.4g}" for k, v in props.items()]
props_desc = ", ".join(props_parts)
print(f"sample_criterion | lambda={lmbd:.4g} | {props_desc} | dist={dist:.4g}")

return dist
8 changes: 6 additions & 2 deletions src/roiextract/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def from_inverse(
roi_method,
subject,
subjects_dir,
verbose=False,
):
src = fwd["src"]
ch_names = fwd["info"]["ch_names"]
mask = get_label_mask(label, src)
W = get_inverse_matrix(inv, fwd, inv_method, lambda2)
w_agg = get_aggregation_weights(roi_method, label, src, subject, subjects_dir)
with mne.use_log_level(verbose):
W = get_inverse_matrix(inv, fwd, inv_method, lambda2)
w_agg = get_aggregation_weights(
roi_method, label, src, subject, subjects_dir
)
w = w_agg @ W[mask, :]
return cls(
np.atleast_1d(np.squeeze(w)),
Expand Down
33 changes: 22 additions & 11 deletions src/roiextract/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
from copy import deepcopy

from roiextract.optimize import ctf_optimize_label
from roiextract.utils import _check_input, logger
from roiextract.utils import _check_input, logger, normalize_values


class OptimizationCurve:
def __init__(self, fwd, label, mode, template, sampling_limit=0.0001):
def __init__(self, fwd, label, mode, template, reg=0.001, sampling_limit=0.0001):
_check_input("mode", mode, ["homogeneity", "similarity"])
self.fwd = fwd
self.label = label
self.mode = mode
self.template = template
self.reg = reg
self.sampling_limit = sampling_limit
self._reset()

def _reset(self):
self.filters = None
self.lambdas = None
self.n_points = None
self._xs = None
self._ys = None

Expand Down Expand Up @@ -50,9 +52,13 @@ def sims(self):
def rats(self):
return self._ys

def plot_curve(self, ax):
ax.plot(self._xs, self._ys, "k-")
ax.scatter(self._xs, self._ys, c=self.lambdas)
def plot(self, ax, normalize=False):
xs, ys = self._xs, self._ys
if normalize:
xs, ys = normalize_values(xs, ys)

ax.plot(xs, ys, "k-")
ax.scatter(xs, ys, c=self.lambdas)
ax.set_aspect("equal")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
Expand All @@ -75,11 +81,17 @@ def sample_adaptive(self, dist_threshold):
return self

def sample_homogeneous(self, n_points):
lambdas = np.linspace(0, 1, num=n_points)

return self.sample_manually(lambdas)

def sample_manually(self, lambdas):
self._reset()
self.filters = []
self.lambdas = np.linspace(0, 1, num=n_points)
self._xs = np.zeros((n_points,))
self._ys = np.zeros((n_points,))
self.lambdas = lambdas
self.n_points = len(lambdas)
self._xs = np.zeros((self.n_points,))
self._ys = np.zeros((self.n_points,))

for i, lmbd in enumerate(self.lambdas):
result = self._sample(lmbd)
Expand Down Expand Up @@ -111,6 +123,7 @@ def _sample(self, lmbd, return_filter=False):
lmbd,
mode=self.mode,
initial="auto",
reg=self.reg,
quantify=True,
)
logger.debug(f"_sample | lambda={lmbd:.3g} | {props}")
Expand All @@ -122,9 +135,7 @@ def _sample(self, lmbd, return_filter=False):
return np.array([props[x_key], props["rat"]])

def _get_filters(self):
self.filters = [
self._sample(lmbd, return_filter=True) for lmbd in self.lambdas
]
self.filters = [self._sample(lmbd, return_filter=True) for lmbd in self.lambdas]

def _get_sampled_values(self, learner):
lambdas = np.array(list(learner.data.keys()))
Expand Down
19 changes: 19 additions & 0 deletions src/roiextract/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

from roiextract.filter import SpatialFilter
from roiextract.inspect import OptimizationCurve


def load_optimization_curve(filename):
oc = OptimizationCurve(None, None, "similarity", None)
with np.load(filename) as data:
oc.lambdas = data["lambdas"]
oc.filters = [
SpatialFilter(w=w, method_params=dict(lambda_=lmbd))
for w, lmbd in zip(data["filters"], oc.lambdas)
]
oc.n_points = len(oc.lambdas)
oc._ys = data["rats"]
oc._xs = data.get("homs", data["sims"])

return oc
14 changes: 4 additions & 10 deletions src/roiextract/quantify.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ctf_ratio(w, L, mask, source_mask=None):
ctf_in = np.squeeze(w @ L_in)

# Adjusted value in the [0, 1] range
return norm(ctf_in) / norm(ctf_all)
return (norm(ctf_in) / norm(ctf_all)) ** 2


def ctf_similarity(w, L, w0, mask, source_mask=None):
Expand Down Expand Up @@ -128,13 +128,9 @@ def ctf_quantify(w, leadfield, mask, w0=None, P0=None, source_mask=None):
result = dict()
result["rat"] = ctf_ratio(w, leadfield, mask, source_mask=source_mask)
if w0 is not None:
result["sim"] = ctf_similarity(
w, leadfield, w0, mask, source_mask=source_mask
)
result["sim"] = ctf_similarity(w, leadfield, w0, mask, source_mask=source_mask)
if P0 is not None:
result["hom"] = ctf_homogeneity(
w, leadfield, P0, mask, source_mask=source_mask
)
result["hom"] = ctf_homogeneity(w, leadfield, P0, mask, source_mask=source_mask)

return result

Expand All @@ -153,9 +149,7 @@ def ctf_quantify_label(w, fwd, label, w0=None, P0=None, source_mask=None):
if P0 is not None:
P0 = resolve_template(P0, label, src)

return ctf_quantify(
w, leadfield, mask, w0=w0, P0=P0, source_mask=source_mask
)
return ctf_quantify(w, leadfield, mask, w0=w0, P0=P0, source_mask=source_mask)


def rec_quantify(w, cov_matrix, inverse, template, mask, source_mask=None):
Expand Down
131 changes: 131 additions & 0 deletions src/roiextract/suggest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import adaptive
import numpy as np

from copy import deepcopy
from functools import partial

from roiextract.inspect import OptimizationCurve
from roiextract.utils import _check_input, normalize_values
from roiextract._sample import sample_criterion


def theta(sims, rats, limits):
sims_norm, rats_norm = normalize_values(sims, rats, limits)

# Normalize to [0, 1] range, make theta=0 match lambda=0
return 1 - 2 * np.arctan(rats_norm / sims_norm) / np.pi


def theta_dist(sims, rats, target, limits):
return theta(sims, rats, limits) - target


class Suggester:
def __init__(self, fwd, label, mode, template, reg=0.001, sampling_limit=0.0001):
_check_input("mode", mode, ["homogeneity", "similarity"])
self.fwd = fwd
self.label = label
self.mode = mode
self.template = template
self.reg = reg
self.sampling_limit = sampling_limit
self._reset()

def _reset(self):
self.filters = None
self.lambdas = None
self.n_points = None
self._xs = None
self._ys = None

def copy(self):
return deepcopy(self)

@property
def homs(self):
if self.mode == "similarity":
raise ValueError(
f"The optimization mode is set to {self.mode}, so the "
f"homogeneity values are not available."
)

return self._xs

@property
def sims(self):
if self.mode == "homogeneity":
raise ValueError(
f"The optimization mode is set to {self.mode}, so the "
f"similarity values are not available."
)

return self._xs

@property
def rats(self):
return self._ys

@staticmethod
def get_sampled_values(learner):
lambdas = np.array(list(learner.data.keys()))
sort_idx = np.argsort(lambdas)
optim_curve = np.vstack(list(learner.data.values()))

lambdas = lambdas[sort_idx]
dist = optim_curve[sort_idx]

return lambdas, dist

def suggest_for_theta(self, target_theta, tol=0.001):
assert target_theta >= 0.0 and target_theta <= 1.0

oc = OptimizationCurve(self.fwd, self.label, self.mode, self.template, self.reg)
oc.sample_manually([0.0, 1.0])

_, _, limits = normalize_values(oc.sims, oc.rats, return_limits=True)

dist_fun = partial(theta_dist, target=target_theta, limits=limits)
loss_fun = partial(self._monotone_search_loss)
sample_fun = partial(
sample_criterion,
fwd=self.fwd,
label=self.label,
template=self.template,
mode=self.mode,
reg=self.reg,
dist_fun=dist_fun,
)

learner = adaptive.Learner1D(
sample_fun,
bounds=(0, 1),
loss_per_interval=loss_fun,
)
adaptive.BlockingRunner(learner, loss_goal=tol, ntasks=1)

lambdas, dist = self.get_sampled_values(learner)
best_lambda = lambdas[np.argmin(np.abs(dist))]

oc.sample_manually([best_lambda])

return best_lambda, oc.filters[0]

def _monotone_search_loss(self, xs, ys):
xs = np.array([x for x in xs if x is not None])
ys = np.array([y for y in ys if y is not None])

if ys.size < 2:
return 0

# TODO: take care of the sampling limit?
# lambda_dist = np.abs(xs[1] - xs[0])

# NOTE: If we sample a monotone function and the distance to the target
# has opposite signs at the ends of an interval, we need to sample
# another point in this interval. If the signs are the same, the
# extremum does not belong to the interval, so we don't need to sample
# this interval at all, setting the loss to 0.
if ys[0] * ys[1] > 0:
return 0

return np.abs(ys).min()
17 changes: 17 additions & 0 deletions src/roiextract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,20 @@ def get_aggregation_weights(method, label, src, subject, subjects_dir):
subject=subject, restrict_vertices=src, subjects_dir=subjects_dir
)
return (label_vertices == centroid_idx).astype(int)


def normalize_values(xs, ys, limits=None, return_limits=False):
if limits is not None:
x_min, x_max, y_min, y_max = limits
else:
x_min, x_max = xs.min(), xs.max()
y_min, y_max = ys.min(), ys.max()
limits = (x_min, x_max, y_min, y_max)

xs_norm = (xs - x_min) / (x_max - x_min)
ys_norm = (ys - y_min) / (y_max - y_min)

if not return_limits:
return xs_norm, ys_norm

return xs_norm, ys_norm, limits
8 changes: 4 additions & 4 deletions tests/test_quantify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def create_data():
@pytest.mark.parametrize(
"mask,source_mask,expected_ratio",
[
[np.array([True, False, False, False]), None, np.sqrt(0.1)],
[np.array([True, True, False, False]), None, np.sqrt(0.2)],
[np.array([False, True, True, False]), None, np.sqrt(0.5)],
[np.array([True, False, False, False]), None, 0.1],
[np.array([True, True, False, False]), None, 0.2],
[np.array([False, True, True, False]), None, 0.5],
[np.array([True, True, True, True]), None, 1.0],
[
np.array([False, True, True, False]),
np.array([False, True, False, True]),
np.sqrt(0.2),
0.2,
],
],
)
Expand Down