diff --git a/sotrplib/handlers/base.py b/sotrplib/handlers/base.py index 9773ae9d..1545c911 100644 --- a/sotrplib/handlers/base.py +++ b/sotrplib/handlers/base.py @@ -16,7 +16,7 @@ from sotrplib.maps.postprocessor import MapPostprocessor from sotrplib.maps.preprocessor import MapPreprocessor from sotrplib.outputs.core import SourceOutput -from sotrplib.sifter.core import EmptySifter, SiftingProvider +from sotrplib.sifter.core import EmptySifter, SifterResult, SiftingProvider from sotrplib.sims.sim_source_generators import ( SimulatedSource, SimulatedSourceGenerator, @@ -29,6 +29,7 @@ ForcedPhotometryProvider, ) from sotrplib.sources.force import EmptyForcedPhotometry +from sotrplib.sources.sources import MeasuredSource from sotrplib.sources.subtractor import EmptySourceSubtractor, SourceSubtractor __all__ = ["BaseRunner"] @@ -116,9 +117,6 @@ def build_map(self, input_map: ProcessableMap) -> ProcessableMap: return output_map - def coadd_maps(self, input_maps: list[ProcessableMap]) -> list[ProcessableMap]: - return self.map_coadder.coadd(input_maps) - @property def bbox(self): if not self.maps: @@ -168,9 +166,22 @@ def simulate_sources(self) -> list[SimulatedSource]: self.source_catalogs.append(catalog) return all_simulated_sources + def coadd_and_analyze_maps( + self, maps: list[ProcessableMap], simulated_sources: list[SimulatedSource] + ) -> tuple[list[MeasuredSource], SifterResult]: + """ + Coadd and analyze maps in a single task to avoid passing maps between processes. + """ + coadded_map = self.profilable_task(self.map_coadder.coadd_maps)(maps) + return self.profilable_task(self.analyze_map)( + input_map=coadded_map, simulated_sources=simulated_sources + ) + def analyze_map( self, input_map: ProcessableMap, simulated_sources: list[SimulatedSource] - ) -> tuple[list, object, ProcessableMap]: + ) -> tuple[list[MeasuredSource], SifterResult]: + input_map = self.profilable_task(self.build_map)(input_map) + self.profilable_task(input_map.finalize)() injected_sources, input_map = self.profilable_task(self.source_injector.inject)( @@ -232,28 +243,26 @@ def analyze_map( self.profilable_task(output.output)( forced_photometry_candidates=forced_photometry_candidates, sifter_result=sifter_result, - input_map=input_map, + map_id=input_map.map_id, pointing_sources=pointing_sources, injected_sources=injected_sources, ) if input_map._parent_database is not None: - set_processing_end(input_map.map_id) - return forced_photometry_candidates, sifter_result, input_map + self.profilable_task(set_processing_end)(input_map.map_id) + return forced_photometry_candidates, sifter_result - def run(self) -> tuple[list[list], list[object], list[ProcessableMap]]: + def run(self) -> tuple[list[list[MeasuredSource]], list[SifterResult]]: return self.flow(self._run)() - def _run(self) -> tuple[list[list], list[object], list[ProcessableMap]]: + def _run(self) -> tuple[list[list[MeasuredSource]], list[SifterResult]]: """ The actual pipeline run logic has to be in a separate method so that it can be decorated with the flow as prefect needs these to be defined in advance. """ all_simulated_sources = self.basic_task(self.simulate_sources)() - self.maps = self.basic_task(self.build_map).map(self.maps).result() - self.maps = [m for m in self.maps if m is not None] - self.maps = self.coadd_maps(self.maps) + self.map_sets = self.basic_task(self.map_coadder.group_maps)(self.maps) return ( - self.basic_task(self.analyze_map) - .map(self.maps, self.unmapped(all_simulated_sources)) + self.basic_task(self.coadd_and_analyze_maps) + .map(self.map_sets, self.unmapped(all_simulated_sources)) .result() ) diff --git a/sotrplib/maps/map_coadding.py b/sotrplib/maps/map_coadding.py index 4c47fe52..b66f10d0 100644 --- a/sotrplib/maps/map_coadding.py +++ b/sotrplib/maps/map_coadding.py @@ -17,12 +17,38 @@ class MapCoadder(ABC): and used interchangeably. """ + @abstractmethod + def coadd_maps(self, input_maps: list[ProcessableMap]) -> ProcessableMap: + """Coadd a list of ProcessableMap objects into a single coadded map.""" + return + + @abstractmethod + def group_maps( + self, input_maps: list[ProcessableMap] + ) -> list[list[ProcessableMap]]: + """ + Group input maps by frequency and array. Returns a list of lists of + ProcessableMap objects. + """ + return + @abstractmethod def coadd(self, input_maps: list[ProcessableMap]) -> list[ProcessableMap]: + """ + Coadd all maps according the coadder's grouping strategy and return a list of + coadded maps. + """ return class EmptyMapCoadder(MapCoadder): + def coadd_maps(self, input_maps: list[ProcessableMap]) -> ProcessableMap: + assert len(input_maps) == 1, "EmptyMapCoadder can only coadd a single map" + return input_maps[0] + + def group_maps(self, input_maps): + return [[input_map] for input_map in input_maps] + def coadd(self, input_maps: list[ProcessableMap]) -> list[ProcessableMap]: return input_maps @@ -56,6 +82,52 @@ def __init__( self.instrument = instrument self.log = log or structlog.get_logger() + def group_maps( + self, input_maps: list[ProcessableMap] + ) -> list[list[ProcessableMap]]: + """ + Group input maps by frequency and array. Returns a list of lists of maps. + Within each list, the maps all have the same frequency and array. + """ + map_sets = list() + for arr in self.arrays: + for freq in self.frequencies: + map_sets.append(self._get_valid_maps(input_maps, freq, arr)) + map_sets = [ms for ms in map_sets if ms] # filter out empty map sets + return map_sets + + def _get_valid_maps(self, input_maps: list[ProcessableMap], freq: str, arr: str): + """ + Identify which maps match the provided frequency and array. + """ + good_maps = [True] * len(input_maps) + for i, imap in enumerate(input_maps): + if imap.frequency != freq: + good_maps[i] = False + if arr != "coadd" and arr != imap.array: + good_maps[i] = False + + if not any(good_maps): + self.log.debug( + "rhokappamapcoadder.coadd.no_good_maps", + n_input_maps=len(input_maps), + frequency=freq, + array=arr, + ) + return list() + if not all(good_maps): + self.log.info( + "rhokappamapcoadder.coadd.dropping_maps", + n_input_maps=len(input_maps), + n_dropped_maps=len([good for good in good_maps if not good]), + frequency=freq, + array=arr, + ) + + valid_input_maps = [imap for imap, good in zip(input_maps, good_maps) if good] + + return valid_input_maps + def coadd(self, input_maps: list[ProcessableMap]): """ Coadd input_maps given the coadder freqs, arrays. @@ -82,127 +154,109 @@ def coadd(self, input_maps: list[ProcessableMap]): coadded_maps = [] for arr in self.arrays: for freq in self.frequencies: - good_maps = [True] * len(input_maps) - for i, imap in enumerate(input_maps): - if imap.frequency != freq: - good_maps[i] = False - if arr != "coadd" and arr != imap.array: - good_maps[i] = False - if not any(good_maps): - self.log.warning( - "rhokappamapcoadder.coadd.no_good_maps", - n_input_maps=len(input_maps), - frequency=freq, - array=arr, - ) + valid_input_maps = self._get_valid_maps(input_maps, freq, arr) + + if not valid_input_maps: continue - if not all(good_maps): - self.log.info( - "rhokappamapcoadder.coadd.dropping_maps", - n_input_maps=len(input_maps), - n_dropped_maps=len([good for good in good_maps if not good]), - frequency=freq, - array=arr, - ) - - valid_input_maps = [ - imap for imap, good in zip(input_maps, good_maps) if good - ] - - base_map = valid_input_maps[0] - base_map.build() - - coadd = CoaddedRhoKappaMap( - rho=base_map.rho, - kappa=base_map.kappa, - observation_start=base_map.observation_start, - observation_end=base_map.observation_end, - time_first=base_map.time_first, - time_mean=base_map.time_mean, - time_last=base_map.time_last, - observation_length=base_map.observation_end - - base_map.observation_start, - array=arr, - frequency=freq, - instrument=self.instrument, - flux_units=base_map.flux_units, - mask=base_map.mask, - map_resolution=base_map.map_resolution, - hits=base_map.hits, - map_ids=[base_map.map_id], - ) - self.log.info( - "rhokappamapcoadder.coadd.built", - map_start_time=coadd.observation_start, - map_end_time=coadd.observation_end, - frequency=freq, - array=arr, - ) + coadded_maps.append(self.coadd_maps(valid_input_maps)) - if len(valid_input_maps) == 1: - self.log.warning( - "rhokappamapcoadder.coadd.single_map_warning", n_maps_coadded=1 - ) - n_maps = 1 - coadded_maps.append(coadd) - continue + return coadded_maps - for sourcemap in valid_input_maps[1:]: - sourcemap.build() - self.log.info( - "rhokappamapcoadder.source_map.built", - map_start_time=sourcemap.observation_start, - map_end_time=sourcemap.observation_end, - map_frequency=sourcemap.frequency, - ) - ## will want to do a weighted sum using inverse variance. - if sourcemap.flux_units != coadd.flux_units: - flux_conv = u.Quantity(1.0, sourcemap.flux_units).to( - coadd.flux_units - ) - sourcemap.rho *= flux_conv - sourcemap.kappa /= flux_conv * flux_conv - - coadd.rho = enmap.map_union( - coadd.rho, - sourcemap.rho, - ) - coadd.kappa = enmap.map_union( - coadd.kappa, - sourcemap.kappa, - ) - if coadd.mask is not None and sourcemap.mask is not None: - coadd.mask = enmap.map_union( - coadd.mask, - enmap.enmap(sourcemap.mask), - ) - elif sourcemap.mask is not None: - coadd.mask = enmap.enmap(sourcemap.mask) - - coadd.update_times(sourcemap) - coadd._hits = enmap.map_union( - coadd.hits, - sourcemap.hits, - ) - coadd.map_ids.append(sourcemap.map_id) - - n_maps = len(coadd.input_map_times) - self.log.info( - "rhokappamapcoadder.coadd.completed", - n_maps_coadded=n_maps, - coadd_start_time=coadd.observation_start, - coadd_end_time=coadd.observation_end, - freq=freq, - arr=arr, - ) - coadd.observation_length = ( - coadd.observation_end - coadd.observation_start + def coadd_maps(self, valid_input_maps: list[ProcessableMap]) -> ProcessableMap: + base_map = valid_input_maps[0] + base_map.build() + freq = base_map.frequency + arr = base_map.array + + coadd = CoaddedRhoKappaMap( + rho=base_map.rho, + kappa=base_map.kappa, + observation_start=base_map.observation_start, + observation_end=base_map.observation_end, + time_first=base_map.time_first, + time_mean=base_map.time_mean, + time_last=base_map.time_last, + observation_length=base_map.observation_end - base_map.observation_start, + array=arr, + frequency=freq, + instrument=self.instrument, + flux_units=base_map.flux_units, + mask=base_map.mask, + map_resolution=base_map.map_resolution, + hits=base_map.hits, + map_ids=[base_map.map_id], + ) + + self.log.info( + "rhokappamapcoadder.coadd.built", + map_start_time=coadd.observation_start, + map_end_time=coadd.observation_end, + frequency=freq, + array=arr, + ) + + if len(valid_input_maps) == 1: + self.log.warning( + "rhokappamapcoadder.coadd.single_map_warning", n_maps_coadded=1 + ) + n_maps = 1 + return coadd + + for sourcemap in valid_input_maps[1:]: + sourcemap.build() + self.log.info( + "rhokappamapcoadder.source_map.built", + map_start_time=sourcemap.observation_start, + map_end_time=sourcemap.observation_end, + map_frequency=sourcemap.frequency, + ) + ## will want to do a weighted sum using inverse variance. + if sourcemap.flux_units != coadd.flux_units: + flux_conv = u.Quantity(1.0, sourcemap.flux_units).to(coadd.flux_units) + sourcemap.rho *= flux_conv + sourcemap.kappa /= flux_conv * flux_conv + + coadd.rho = enmap.map_union( + coadd.rho, + sourcemap.rho, + ) + coadd.kappa = enmap.map_union( + coadd.kappa, + sourcemap.kappa, + ) + if coadd.mask is not None and sourcemap.mask is not None: + coadd.mask = enmap.map_union( + coadd.mask, + enmap.enmap(sourcemap.mask), ) + elif sourcemap.mask is not None: + coadd.mask = enmap.enmap(sourcemap.mask) - if coadd.mask is not None: - coadd.mask[coadd.mask > 0] = 1 + coadd.update_times(sourcemap) + coadd._hits = enmap.map_union( + coadd.hits, + sourcemap.hits, + ) + coadd.map_ids.append(sourcemap.map_id) - coadded_maps.append(coadd) + n_maps = len(coadd.input_map_times) + self.log.info( + "rhokappamapcoadder.coadd.completed", + n_maps_coadded=n_maps, + coadd_start_time=coadd.observation_start, + coadd_end_time=coadd.observation_end, + freq=freq, + arr=arr, + ) + self.log.info( + "rhokappamapcoadder.sourcemap.time_mean", + time_mean=sourcemap.time_mean, + alt_time_mean=coadd.time_mean, + ) + coadd.observation_length = coadd.observation_end - coadd.observation_start - return coadded_maps + if coadd.mask is not None: + coadd.mask[coadd.mask > 0] = 1 + + return coadd diff --git a/sotrplib/outputs/core.py b/sotrplib/outputs/core.py index 8a8834ce..863405a5 100644 --- a/sotrplib/outputs/core.py +++ b/sotrplib/outputs/core.py @@ -10,7 +10,6 @@ import matplotlib.pyplot as plt -from sotrplib.maps.core import ProcessableMap from sotrplib.sifter.core import SifterResult from sotrplib.sims.sim_sources import SimulatedSource from sotrplib.sources.sources import MeasuredSource @@ -22,7 +21,7 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, + map_id: str, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): @@ -47,7 +46,6 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): @@ -83,18 +81,18 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, + map_id: str, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): filename = ( self.directory - / f"{input_map.get_map_str_id()}_{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d-%H-%M-%S')}.pickle" + / f"{map_id}_{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d-%H-%M-%S')}.pickle" ) with filename.open("wb") as handle: pickle.dump( obj={ - "map_id": input_map.get_map_str_id(), + "map_id": map_id, "forced_photometry": forced_photometry_candidates, "sifted_blind_search": sifter_result, "pointing_sources": pointing_sources, @@ -120,7 +118,7 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, + map_id: str, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): @@ -130,7 +128,9 @@ def output( if cutout is None: continue - filename = self.directory / f"forced_photometry_{source.source_id}.png" + filename = ( + self.directory / f"forced_photometry_{map_id}_{source.source_id}.png" + ) plt.imsave(fname=filename, arr=cutout, cmap="viridis") for ii, source in enumerate( @@ -143,9 +143,9 @@ def output( if source.crossmatches: id = source.crossmatches[0].source_id - filename = self.directory / f"blind_search_matched_{id}.png" + filename = self.directory / f"blind_search_matched_{map_id}_{id}.png" else: - filename = self.directory / f"blind_search_{ii}.png" + filename = self.directory / f"blind_search_{map_id}_{ii}.png" plt.imsave(fname=filename, arr=cutout, cmap="viridis") diff --git a/sotrplib/outputs/lightcurvedb.py b/sotrplib/outputs/lightcurvedb.py index a6606090..04fa4a2b 100644 --- a/sotrplib/outputs/lightcurvedb.py +++ b/sotrplib/outputs/lightcurvedb.py @@ -13,7 +13,6 @@ from structlog import get_logger from structlog.types import FilteringBoundLogger -from sotrplib.maps.core import ProcessableMap from sotrplib.sifter.core import SifterResult from sotrplib.sims.sim_sources import SimulatedSource from sotrplib.sources.sources import MeasuredSource @@ -229,7 +228,7 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, + map_id: str, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): @@ -238,8 +237,7 @@ def output( successful_uploads = asyncio.run( self._flux_upload_flow( forced_photometry_candidates, - map_time=input_map.observation_time, - map_id=input_map.map_id, + map_id=map_id, ) ) diff --git a/sotrplib/outputs/lightserve.py b/sotrplib/outputs/lightserve.py index 5e0bee75..79b69593 100644 --- a/sotrplib/outputs/lightserve.py +++ b/sotrplib/outputs/lightserve.py @@ -9,7 +9,6 @@ from structlog import get_logger from structlog.types import FilteringBoundLogger -from sotrplib.maps.core import ProcessableMap from sotrplib.sifter.core import SifterResult from sotrplib.sims.sim_sources import SimulatedSource from sotrplib.sources.sources import MeasuredSource @@ -49,7 +48,7 @@ def output( self, forced_photometry_candidates: list[MeasuredSource], sifter_result: SifterResult, - input_map: ProcessableMap, + map_id: str, pointing_sources: list[MeasuredSource] = [], # for compatibility injected_sources: list[SimulatedSource] = [], # for compatibility ): @@ -76,8 +75,7 @@ def output( frequency=90, module="i1", source_id=socat_to_internal[int(source.crossmatches[0].source_id)], - time=source.observation_mean_time.to_datetime() - or input_map.observation_time.to_datetime(), + time=source.observation_mean_time.to_datetime(), ra=source.ra.to_value("deg"), dec=source.dec.to_value("deg"), ra_uncertainty=( @@ -95,16 +93,15 @@ def output( else 0.0 ), extra={ - "map_id": input_map.map_id, + "map_id": map_id, }, ).model_dump_json() cut = ( Cutout( data=source.thumbnail.tolist(), - time=source.observation_mean_time.to_datetime() - or input_map.observation_time.to_datetime(), - units=input_map.flux_units.to_string(), + time=source.observation_mean_time.to_datetime(), + units=source.thumbnail_unit.to_string(), frequency=90, module="i1", ).model_dump_json() diff --git a/tests/regression/validate_results.py b/tests/regression/validate_results.py index bbdbc54e..939796de 100644 --- a/tests/regression/validate_results.py +++ b/tests/regression/validate_results.py @@ -61,7 +61,7 @@ def main(): profiler.write_html("profile.html") print(profiler.output_text(unicode=True, color=True)) result = results[0] - photometry, sifter_result, _ = result + photometry, sifter_result = result found = [candidate for candidate in photometry if not candidate.fit_failed] # some forced photometry candidates don't have offsets diff --git a/tests/test_prefect_pipeline/test_prefect_pipeline.py b/tests/test_prefect_pipeline/test_prefect_pipeline.py index 67d234fd..5f8416e1 100644 --- a/tests/test_prefect_pipeline/test_prefect_pipeline.py +++ b/tests/test_prefect_pipeline/test_prefect_pipeline.py @@ -8,7 +8,6 @@ from sotrplib.outputs.core import PickleSerializer from sotrplib.sifter.core import DefaultSifter, SifterResult from sotrplib.sims.maps import SimulatedMap -from sotrplib.sims.sources.core import ProcessableMapWithSimulatedSources from sotrplib.source_catalog.core import RegisteredSourceCatalog from sotrplib.sources.blind import SigmaClipBlindSearch from sotrplib.sources.force import Scipy2DGaussianFitter @@ -28,7 +27,7 @@ def test_basic_pipeline_scipy( runner = PrefectRunner( maps=maps, map_coadder=None, - source_catalogs=[], + source_catalogs=[source_cat], sso_catalogs=[], source_injector=None, preprocessors=None, @@ -50,8 +49,6 @@ def test_basic_pipeline_scipy( def _validate_pipeline_result(result, nmaps): assert len(result) == nmaps - print(result[0]) - candidates, sifter_result, output_map = result[0] + candidates, sifter_result = result[0] assert all([isinstance(candidate, MeasuredSource) for candidate in candidates]) assert isinstance(sifter_result, SifterResult) - assert isinstance(output_map, ProcessableMapWithSimulatedSources)