From 14537801fae413925d0640dc4ef7feaeca1357ba Mon Sep 17 00:00:00 2001 From: Xavier Garrido Date: Thu, 23 Feb 2023 16:43:33 +0100 Subject: [PATCH 1/2] Implement storage and bin selection for TopHatWindow (and inherited class) - the `get_section` method in now an empty method from `BaseWindow` class and is properly implemented for inherited class - when adding window with `add_ell_cl` the duplication of windows with corresponding window indices is also done for TopHatWindow and its inherited class such as `LogTopHatWindow` --- sacc/sacc.py | 18 +++++++----------- sacc/windows.py | 34 +++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/sacc/sacc.py b/sacc/sacc.py index 7ef40eb..40685ed 100644 --- a/sacc/sacc.py +++ b/sacc/sacc.py @@ -970,13 +970,8 @@ def get_bandpower_windows(self, indices): "tracers and data type) or get windows " "later.") ws = ws[0] - if not isinstance(ws, BandpowerWindow): - warnings.warn("No bandpower windows associated to these data") - return None - else: - w_inds = np.array(self._get_tags_by_index(['window_ind'], - indices)[0]) - return ws.get_section(w_inds) + w_inds = np.array(self._get_tags_by_index(['window_ind'], indices)[0]) + return ws.get_section(w_inds) def get_ell_cl(self, data_type, tracer1, tracer2, return_cov=False, return_ind=False): @@ -1200,12 +1195,13 @@ def add_ell_cl(self, data_type, tracer1, tracer2, ell, x, None """ - if isinstance(window, BandpowerWindow): - if len(ell) != window.nv: + if isinstance(window, (BandpowerWindow, TopHatWindow)): + nv = window.nv if isinstance(window, BandpowerWindow) else len(window.min) + if len(ell) != nv: raise ValueError("Input bandpowers are misshapen") - tag_extra = range(window.nv) + tag_extra = range(nv) tag_extra_name = "window_ind" - window_use = [window for i in range(window.nv)] + window_use = [window for i in range(nv)] else: tag_extra = None tag_extra_name = None diff --git a/sacc/windows.py b/sacc/windows.py index 4f7fcec..7085159 100644 --- a/sacc/windows.py +++ b/sacc/windows.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from astropy.table import Table @@ -78,6 +80,23 @@ def from_tables(cls, table_list): windows.update(subclass.from_table(table)) return windows + def get_section(self, indices): + """Get part of this window function corresponding to the input + indices. + + Parameters + ---------- + indices: int or array_like + Indices to return. + + Returns + ------- + window: `Window` + A `Window object. + """ + warnings.warn("No bandpower windows associated to these data") + return None + class TopHatWindow(BaseWindow, window_type='TopHat'): """A window function that is constant between two values. @@ -158,6 +177,8 @@ def from_table(cls, table): """ return {row['id']: cls(row['min'], row['max']) for row in table} + def get_section(self, indices): + return self.__class__(self.min[indices], self.max[indices]) class LogTopHatWindow(TopHatWindow, window_type='LogTopHat'): """A window function that is log-constant between two values. @@ -317,17 +338,4 @@ def from_table(cls, table): return {table.meta['SACCNAME']: cls(table['values'], table['weight'])} def get_section(self, indices): - """Get part of this window function corresponding to the input - indices. - - Parameters - ---------- - indices: int or array_like - Indices to return. - - Returns - ------- - window: `Window` - A `Window object. - """ return self.__class__(self.values, self.weight[:, indices]) From 1e1359d041297474f95414022fc33db8b5920a17 Mon Sep 17 00:00:00 2001 From: Xavier Garrido Date: Thu, 23 Feb 2023 18:10:07 +0100 Subject: [PATCH 2/2] fix import --- sacc/sacc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacc/sacc.py b/sacc/sacc.py index 40685ed..935fe9e 100644 --- a/sacc/sacc.py +++ b/sacc/sacc.py @@ -8,7 +8,7 @@ from astropy.table import Table from .tracers import BaseTracer -from .windows import BaseWindow, BandpowerWindow +from .windows import BaseWindow, BandpowerWindow, TopHatWindow from .covariance import BaseCovariance, concatenate_covariances from .utils import unique_list from .data_types import standard_types, DataPoint