diff --git a/sacc/sacc.py b/sacc/sacc.py index 7ef40eb..935fe9e 100644 --- a/sacc/sacc.py +++ b/sacc/sacc.py @@ -8,7 +8,7 @@ from astropy.table import Table from .tracers import BaseTracer -from .windows import BaseWindow, BandpowerWindow +from .windows import BaseWindow, BandpowerWindow, TopHatWindow from .covariance import BaseCovariance, concatenate_covariances from .utils import unique_list from .data_types import standard_types, DataPoint @@ -970,13 +970,8 @@ def get_bandpower_windows(self, indices): "tracers and data type) or get windows " "later.") ws = ws[0] - if not isinstance(ws, BandpowerWindow): - warnings.warn("No bandpower windows associated to these data") - return None - else: - w_inds = np.array(self._get_tags_by_index(['window_ind'], - indices)[0]) - return ws.get_section(w_inds) + w_inds = np.array(self._get_tags_by_index(['window_ind'], indices)[0]) + return ws.get_section(w_inds) def get_ell_cl(self, data_type, tracer1, tracer2, return_cov=False, return_ind=False): @@ -1200,12 +1195,13 @@ def add_ell_cl(self, data_type, tracer1, tracer2, ell, x, None """ - if isinstance(window, BandpowerWindow): - if len(ell) != window.nv: + if isinstance(window, (BandpowerWindow, TopHatWindow)): + nv = window.nv if isinstance(window, BandpowerWindow) else len(window.min) + if len(ell) != nv: raise ValueError("Input bandpowers are misshapen") - tag_extra = range(window.nv) + tag_extra = range(nv) tag_extra_name = "window_ind" - window_use = [window for i in range(window.nv)] + window_use = [window for i in range(nv)] else: tag_extra = None tag_extra_name = None diff --git a/sacc/windows.py b/sacc/windows.py index 4f7fcec..7085159 100644 --- a/sacc/windows.py +++ b/sacc/windows.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from astropy.table import Table @@ -78,6 +80,23 @@ def from_tables(cls, table_list): windows.update(subclass.from_table(table)) return windows + def get_section(self, indices): + """Get part of this window function corresponding to the input + indices. + + Parameters + ---------- + indices: int or array_like + Indices to return. + + Returns + ------- + window: `Window` + A `Window object. + """ + warnings.warn("No bandpower windows associated to these data") + return None + class TopHatWindow(BaseWindow, window_type='TopHat'): """A window function that is constant between two values. @@ -158,6 +177,8 @@ def from_table(cls, table): """ return {row['id']: cls(row['min'], row['max']) for row in table} + def get_section(self, indices): + return self.__class__(self.min[indices], self.max[indices]) class LogTopHatWindow(TopHatWindow, window_type='LogTopHat'): """A window function that is log-constant between two values. @@ -317,17 +338,4 @@ def from_table(cls, table): return {table.meta['SACCNAME']: cls(table['values'], table['weight'])} def get_section(self, indices): - """Get part of this window function corresponding to the input - indices. - - Parameters - ---------- - indices: int or array_like - Indices to return. - - Returns - ------- - window: `Window` - A `Window object. - """ return self.__class__(self.values, self.weight[:, indices])