Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"emcee>=3.1.0",
"hciplot>=0.2.4",
"matplotlib>=3.7.0",
"munch>=3.0.0",
"nestle>=0.2.0",
"numpy>=1.21.2",
"pandas>=1.3.3",
Expand Down
100 changes: 48 additions & 52 deletions src/vip_hci/metrics/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
from hciplot import plot_frames
from scipy import stats
from photutils.segmentation import detect_sources
from munch import Munch
from ..config import time_ini, timing, Progressbar
from ..fm import cube_inject_companions
from ..psfsub.svd import SVDecomposer
from ..var import frame_center, get_annulus_segments, get_circle

# TODO: remove the munch dependency


class EvalRoc(object):
"""
Expand Down Expand Up @@ -68,7 +65,7 @@ def add_algo(self, name, algo, color, symbol, thresholds):
thresholds : list of lists

"""
self.methods.append(Munch(algo=algo, name=name, color=color,
self.methods.append(dict(algo=algo, name=name, color=color,
symbol=symbol, thresholds=thresholds))

def inject_and_postprocess(self, patch_size, cevr=0.9,
Expand Down Expand Up @@ -97,11 +94,11 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
print("{}% of CEVR with {} PCs".format(cevr, self.optpcs))

# for m in methods:
# if hasattr(m, "ncomp") and m.ncomp is None: # PCA
# m.ncomp = self.optpcs
# if m.get("ncomp", object()) is None: # PCA
# m["ncomp"] = self.optpcs
#
# if hasattr(m, "rank") and m.rank is None: # LLSG
# m.rank = self.optpcs
# if m.get("rank", object()) is None: # LLSG
# m["rank"] = self.optpcs

#
# ------> this should be moved inside the HCIPostProcAlgo classes!
Expand Down Expand Up @@ -135,8 +132,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
self.thetas.append(theta)

for m in self.methods:
m.frames = []
m.probmaps = []
m["frames"] = []
m["probmaps"] = []

self.list_xy = []

Expand All @@ -157,7 +154,7 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
# TODO: this is not elegant at all.
# shallow copy. Should not copy e.g. the cube in memory,
# just reference it.
algo = copy.copy(m.algo)
algo = copy.copy(m["algo"])
_dataset = copy.copy(self.dataset)
_dataset.cube = cufc

Expand All @@ -169,8 +166,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
algo.run(dataset=_dataset, verbose=False)
algo.make_snrmap(approximated=True, nproc=nproc, verbose=False)

m.frames.append(algo.frame_final)
m.probmaps.append(algo.snr_map)
m["frames"].append(algo.frame_final)
m["probmaps"].append(algo.snr_map)

timing(starttime)

Expand All @@ -192,22 +189,22 @@ def compute_tpr_fps(self, **kwargs):
starttime = time_ini()

for m in self.methods:
m.detections = []
m.fps = []
m.bmaps = []
m["detections"] = []
m["fps"] = []
m["bmaps"] = []

print('Evaluating injections:')
for i in Progressbar(range(self.n_injections)):
x, y = self.list_xy[i]

for m in self.methods:
dets, fps, bmaps = compute_binary_map(
m.probmaps[i], m.thresholds, fwhm=self.dataset.fwhm,
m["probmaps"][i], m["thresholds"], fwhm=self.dataset.fwhm,
injections=(x, y), **kwargs
)
m.detections.append(dets)
m.fps.append(fps)
m.bmaps.append(bmaps)
m["detections"].append(dets)
m["fps"].append(fps)
m["bmaps"].append(bmaps)

timing(starttime)

Expand Down Expand Up @@ -245,9 +242,9 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,

if vmax == 'max':
# TODO: document this feature.
vmax = np.concatenate([m.frames[i] for m in self.methods if
hasattr(m, "frames") and
len(m.frames) >= i]).max()/2
vmax = np.concatenate([m["frames"][i] for m in self.methods if
"frames" in m and
len(m["frames"]) >= i]).max()/2

# print information
print('X,Y: {}'.format(self.list_xy[i]))
Expand All @@ -258,33 +255,32 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,
if plot_type in [1, "horiz"]:
for m in self.methods:
print('detection state: {} | false postives: {}'.format(
m.detections[i][thr], m.fps[i][thr]))
labels = ('{} frame'.format(m.name), '{} S/Nmap'.format(m.name),
'Thresholded at {:.1f}'.format(m.thresholds[thr]))
plot_frames((m.frames[i] if len(m.frames) >= i else
np.zeros((2, 2)), m.probmaps[i], m.bmaps[i][thr]),
m["detections"][i][thr], m["fps"][i][thr]))
labels = (f"{m['name']} frame", f"{m['name']} S/Nmap",
f"Thresholded at {m['thresholds'][thr]:.1f}")
plot_frames((m["frames"][i] if len(m["frames"]) >= i else
np.zeros((2, 2)), m["probmaps"][i], m["bmaps"][i][thr]),
label=labels, dpi=dpi, horsp=0.2, axis=axis,
grid=grid, cmap=['viridis', 'viridis', 'gray'])

elif plot_type in [2, "vert"]:
labels = tuple('{} frame'.format(m.name) for m in self.methods if
hasattr(m, "frames") and len(m.frames) >= i)
plot_frames(tuple(m.frames[i] for m in self.methods if
hasattr(m, "frames") and len(m.frames) >= i),
labels = tuple(f"{m['name']} frame" for m in self.methods if
"frames" in m and len(m["frames"]) >= i)
plot_frames(tuple(m["frames"][i] for m in self.methods if
"frames" in m and len(m["frames"]) >= i),
dpi=dpi, label=labels, vmax=vmax, vmin=vmin, axis=axis,
grid=grid)

plot_frames(tuple(m.probmaps[i] for m in self.methods), dpi=dpi,
label=tuple(['{} S/Nmap'.format(m.name) for m in
plot_frames(tuple(m["probmaps"][i] for m in self.methods), dpi=dpi,
label=tuple([f"{m['name']} S/Nmap" for m in
self.methods]), axis=axis, grid=grid)

for m in self.methods:
msg = '{} detection: {}, FPs: {}'
print(msg.format(m.name, m.detections[i][thr], m.fps[i][thr]))
print(f"{m['name']} detection: {m['detections'][i][thr]}, FPs: {m['fps'][i][thr]}")

labels = tuple('Thresholded at {:.1f}'.format(m.thresholds[thr])
labels = tuple(f"Thresholded at {m['thresholds'][thr]:.1f}"
for m in self.methods)
plot_frames(tuple(m.bmaps[i][thr] for m in self.methods),
plot_frames(tuple(m["bmaps"][i][thr] for m in self.methods),
dpi=dpi, label=labels, axis=axis, grid=grid,
colorbar=False, cmap='bone')
else:
Expand Down Expand Up @@ -342,40 +338,40 @@ def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None,
# "SODIRF": dict(color="#9467bd", symbol="s"),
# "SODINN": dict(color="#1f77b4", symbol="p"),
# "SODINN-pw": dict(color="#1f77b4", symbol="p")
# } # maps m.name to plot style
# } # maps m["name"] to plot style

for i, m in enumerate(self.methods):

if not hasattr(m, "detections") or not hasattr(m, "fps"):
raise AttributeError("method #{} has no detections/fps. Run"
"`compute_tpr_fps` first.".format(i))

m.tpr = np.zeros((n_thresholds))
m.mean_fps = np.zeros((n_thresholds))
m["tpr"] = np.zeros(n_thresholds)
m["mean_fps"] = np.zeros(n_thresholds)

for j in range(n_thresholds):
m.tpr[j] = np.asarray(m.detections)[:, j].tolist().count(1) / \
m["tpr"][j] = np.asarray(m["detections"])[:, j].tolist().count(1) / \
self.n_injections
m.mean_fps[j] = np.asarray(m.fps)[:, j].mean()
m["mean_fps"][j] = np.asarray(m["fps"])[:, j].mean()

plt.plot(m.mean_fps, m.tpr, '--', color=m.color, **linekw)
plt.plot(m.mean_fps, m.tpr, m.symbol, label=m.name, color=m.color,
plt.plot(m["mean_fps"], m["tpr"], '--', color=m["color"], **linekw)
plt.plot(m["mean_fps"], m["tpr"], m["symbol"], label=m["name"], color=m["color"],
**markerkw)

if show_data_labels:
if label_skip_one[i]:
lab_x = m.mean_fps[1::2]
lab_y = m.tpr[1::2]
thr = m.thresholds[1::2]
lab_x = m["mean_fps"][1::2]
lab_y = m["tpr"][1::2]
thr = m["thresholds"][1::2]
else:
lab_x = m.mean_fps
lab_y = m.tpr
thr = m.thresholds
lab_x = m["mean_fps"]
lab_y = m["tpr"]
thr = m["thresholds"]

for i, xy in enumerate(zip(lab_x + label_gap[0],
lab_y + label_gap[1])):
labels.append(ax.annotate('{:.2f}'.format(thr[i]),
xy=xy, xycoords='data', color=m.color,
xy=xy, xycoords='data', color=m["color"],
**labelskw))
# TODO: reverse order of `self.methods` for better annot.
# z-index?
Expand Down