diff --git a/pyproject.toml b/pyproject.toml index 6477b1e7..72af31e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "emcee>=3.1.0", "hciplot>=0.2.4", "matplotlib>=3.7.0", - "munch>=3.0.0", "nestle>=0.2.0", "numpy>=1.21.2", "pandas>=1.3.3", diff --git a/src/vip_hci/metrics/roc.py b/src/vip_hci/metrics/roc.py index dc15fe8e..62d6f434 100644 --- a/src/vip_hci/metrics/roc.py +++ b/src/vip_hci/metrics/roc.py @@ -10,14 +10,11 @@ from hciplot import plot_frames from scipy import stats from photutils.segmentation import detect_sources -from munch import Munch from ..config import time_ini, timing, Progressbar from ..fm import cube_inject_companions from ..psfsub.svd import SVDecomposer from ..var import frame_center, get_annulus_segments, get_circle -# TODO: remove the munch dependency - class EvalRoc(object): """ @@ -68,7 +65,7 @@ def add_algo(self, name, algo, color, symbol, thresholds): thresholds : list of lists """ - self.methods.append(Munch(algo=algo, name=name, color=color, + self.methods.append(dict(algo=algo, name=name, color=color, symbol=symbol, thresholds=thresholds)) def inject_and_postprocess(self, patch_size, cevr=0.9, @@ -97,11 +94,11 @@ def inject_and_postprocess(self, patch_size, cevr=0.9, print("{}% of CEVR with {} PCs".format(cevr, self.optpcs)) # for m in methods: - # if hasattr(m, "ncomp") and m.ncomp is None: # PCA - # m.ncomp = self.optpcs + # if m.get("ncomp", object()) is None: # PCA + # m["ncomp"] = self.optpcs # - # if hasattr(m, "rank") and m.rank is None: # LLSG - # m.rank = self.optpcs + # if m.get("rank", object()) is None: # LLSG + # m["rank"] = self.optpcs # # ------> this should be moved inside the HCIPostProcAlgo classes! @@ -135,8 +132,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9, self.thetas.append(theta) for m in self.methods: - m.frames = [] - m.probmaps = [] + m["frames"] = [] + m["probmaps"] = [] self.list_xy = [] @@ -157,7 +154,7 @@ def inject_and_postprocess(self, patch_size, cevr=0.9, # TODO: this is not elegant at all. # shallow copy. Should not copy e.g. the cube in memory, # just reference it. - algo = copy.copy(m.algo) + algo = copy.copy(m["algo"]) _dataset = copy.copy(self.dataset) _dataset.cube = cufc @@ -169,8 +166,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9, algo.run(dataset=_dataset, verbose=False) algo.make_snrmap(approximated=True, nproc=nproc, verbose=False) - m.frames.append(algo.frame_final) - m.probmaps.append(algo.snr_map) + m["frames"].append(algo.frame_final) + m["probmaps"].append(algo.snr_map) timing(starttime) @@ -192,9 +189,9 @@ def compute_tpr_fps(self, **kwargs): starttime = time_ini() for m in self.methods: - m.detections = [] - m.fps = [] - m.bmaps = [] + m["detections"] = [] + m["fps"] = [] + m["bmaps"] = [] print('Evaluating injections:') for i in Progressbar(range(self.n_injections)): @@ -202,12 +199,12 @@ def compute_tpr_fps(self, **kwargs): for m in self.methods: dets, fps, bmaps = compute_binary_map( - m.probmaps[i], m.thresholds, fwhm=self.dataset.fwhm, + m["probmaps"][i], m["thresholds"], fwhm=self.dataset.fwhm, injections=(x, y), **kwargs ) - m.detections.append(dets) - m.fps.append(fps) - m.bmaps.append(bmaps) + m["detections"].append(dets) + m["fps"].append(fps) + m["bmaps"].append(bmaps) timing(starttime) @@ -245,9 +242,9 @@ def plot_detmaps(self, i=None, thr=9, dpi=100, if vmax == 'max': # TODO: document this feature. - vmax = np.concatenate([m.frames[i] for m in self.methods if - hasattr(m, "frames") and - len(m.frames) >= i]).max()/2 + vmax = np.concatenate([m["frames"][i] for m in self.methods if + "frames" in m and + len(m["frames"]) >= i]).max()/2 # print information print('X,Y: {}'.format(self.list_xy[i])) @@ -258,33 +255,32 @@ def plot_detmaps(self, i=None, thr=9, dpi=100, if plot_type in [1, "horiz"]: for m in self.methods: print('detection state: {} | false postives: {}'.format( - m.detections[i][thr], m.fps[i][thr])) - labels = ('{} frame'.format(m.name), '{} S/Nmap'.format(m.name), - 'Thresholded at {:.1f}'.format(m.thresholds[thr])) - plot_frames((m.frames[i] if len(m.frames) >= i else - np.zeros((2, 2)), m.probmaps[i], m.bmaps[i][thr]), + m["detections"][i][thr], m["fps"][i][thr])) + labels = (f"{m['name']} frame", f"{m['name']} S/Nmap", + f"Thresholded at {m['thresholds'][thr]:.1f}") + plot_frames((m["frames"][i] if len(m["frames"]) >= i else + np.zeros((2, 2)), m["probmaps"][i], m["bmaps"][i][thr]), label=labels, dpi=dpi, horsp=0.2, axis=axis, grid=grid, cmap=['viridis', 'viridis', 'gray']) elif plot_type in [2, "vert"]: - labels = tuple('{} frame'.format(m.name) for m in self.methods if - hasattr(m, "frames") and len(m.frames) >= i) - plot_frames(tuple(m.frames[i] for m in self.methods if - hasattr(m, "frames") and len(m.frames) >= i), + labels = tuple(f"{m['name']} frame" for m in self.methods if + "frames" in m and len(m["frames"]) >= i) + plot_frames(tuple(m["frames"][i] for m in self.methods if + "frames" in m and len(m["frames"]) >= i), dpi=dpi, label=labels, vmax=vmax, vmin=vmin, axis=axis, grid=grid) - plot_frames(tuple(m.probmaps[i] for m in self.methods), dpi=dpi, - label=tuple(['{} S/Nmap'.format(m.name) for m in + plot_frames(tuple(m["probmaps"][i] for m in self.methods), dpi=dpi, + label=tuple([f"{m['name']} S/Nmap" for m in self.methods]), axis=axis, grid=grid) for m in self.methods: - msg = '{} detection: {}, FPs: {}' - print(msg.format(m.name, m.detections[i][thr], m.fps[i][thr])) + print(f"{m['name']} detection: {m['detections'][i][thr]}, FPs: {m['fps'][i][thr]}") - labels = tuple('Thresholded at {:.1f}'.format(m.thresholds[thr]) + labels = tuple(f"Thresholded at {m['thresholds'][thr]:.1f}" for m in self.methods) - plot_frames(tuple(m.bmaps[i][thr] for m in self.methods), + plot_frames(tuple(m["bmaps"][i][thr] for m in self.methods), dpi=dpi, label=labels, axis=axis, grid=grid, colorbar=False, cmap='bone') else: @@ -342,7 +338,7 @@ def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None, # "SODIRF": dict(color="#9467bd", symbol="s"), # "SODINN": dict(color="#1f77b4", symbol="p"), # "SODINN-pw": dict(color="#1f77b4", symbol="p") - # } # maps m.name to plot style + # } # maps m["name"] to plot style for i, m in enumerate(self.methods): @@ -350,32 +346,32 @@ def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None, raise AttributeError("method #{} has no detections/fps. Run" "`compute_tpr_fps` first.".format(i)) - m.tpr = np.zeros((n_thresholds)) - m.mean_fps = np.zeros((n_thresholds)) + m["tpr"] = np.zeros(n_thresholds) + m["mean_fps"] = np.zeros(n_thresholds) for j in range(n_thresholds): - m.tpr[j] = np.asarray(m.detections)[:, j].tolist().count(1) / \ + m["tpr"][j] = np.asarray(m["detections"])[:, j].tolist().count(1) / \ self.n_injections - m.mean_fps[j] = np.asarray(m.fps)[:, j].mean() + m["mean_fps"][j] = np.asarray(m["fps"])[:, j].mean() - plt.plot(m.mean_fps, m.tpr, '--', color=m.color, **linekw) - plt.plot(m.mean_fps, m.tpr, m.symbol, label=m.name, color=m.color, + plt.plot(m["mean_fps"], m["tpr"], '--', color=m["color"], **linekw) + plt.plot(m["mean_fps"], m["tpr"], m["symbol"], label=m["name"], color=m["color"], **markerkw) if show_data_labels: if label_skip_one[i]: - lab_x = m.mean_fps[1::2] - lab_y = m.tpr[1::2] - thr = m.thresholds[1::2] + lab_x = m["mean_fps"][1::2] + lab_y = m["tpr"][1::2] + thr = m["thresholds"][1::2] else: - lab_x = m.mean_fps - lab_y = m.tpr - thr = m.thresholds + lab_x = m["mean_fps"] + lab_y = m["tpr"] + thr = m["thresholds"] for i, xy in enumerate(zip(lab_x + label_gap[0], lab_y + label_gap[1])): labels.append(ax.annotate('{:.2f}'.format(thr[i]), - xy=xy, xycoords='data', color=m.color, + xy=xy, xycoords='data', color=m["color"], **labelskw)) # TODO: reverse order of `self.methods` for better annot. # z-index?