diff --git a/starfish/core/types/_spot_finding_results.py b/starfish/core/types/_spot_finding_results.py
index b3dc71d65..a9a910262 100644
--- a/starfish/core/types/_spot_finding_results.py
+++ b/starfish/core/types/_spot_finding_results.py
@@ -1,12 +1,13 @@
+import json
+import os
 from dataclasses import dataclass
-from typing import Any, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple
+from typing import Any, Dict, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple
 
 import xarray as xr
 
 from starfish.core.types import Axes, Coordinates, SpotAttributes
 from starfish.core.util.logging import Log
 
-
 AXES_ORDER = (Axes.ROUND, Axes.CH)
 
 
@@ -109,6 +110,100 @@ def values(self):
         """
         return self._results.values()
 
+    def save(self, output_dir_name: str) -> None:
+        """Save spot finding results to a series of files.
+
+        Every file name is ``output_dir_name`` plus a fixed suffix, so end it with a path
+        separator to write into a directory. The ``SpotFindingResults.json`` index lists the
+        other files by name only and is written last.
+
+        Parameters
+        ----------
+        output_dir_name: str
+            Directory (with trailing separator) or path prefix for all output files.
+
+        """
+        # Write through explicit paths instead of os.chdir(): the working directory is
+        # process-wide state that stays changed if a write fails, and chdir('') raises when
+        # output_dir_name has no directory component.
+        dir_name = os.path.dirname(output_dir_name)
+        base_name = os.path.basename(output_dir_name)
+        json_data: Dict[str, Any] = {}
+
+        coords = {}
+        for key, coord_range in self.physical_coord_ranges.items():
+            file_name = "{}coords_{}.nc".format(base_name, key)
+            coords[key] = file_name
+            coord_range.to_netcdf(os.path.join(dir_name, file_name))
+        json_data["physical_coord_ranges"] = coords
+
+        file_name = "{}log.arr".format(base_name)
+        json_data["log"] = {"path": file_name}
+        with open(os.path.join(dir_name, file_name), "w") as f:
+            f.write(self.log.encode())
+
+        spot_attrs = {}
+        for (round_label, ch_label), result in self._results.items():
+            file_name = "{}spots_{}_{}.nc".format(base_name, round_label, ch_label)
+            spot_attrs["{}_{}".format(round_label, ch_label)] = file_name
+            result.spot_attrs.save(os.path.join(dir_name, file_name))
+        json_data["spot_attrs"] = spot_attrs
+
+        index_name = "{}SpotFindingResults.json".format(base_name)
+        with open(os.path.join(dir_name, index_name), "w") as f:
+            json.dump(json_data, f)
+
+    @classmethod
+    def load(cls, json_file: str):
+        """Load serialized spot finding results.
+
+        Parameters
+        ----------
+        json_file: str
+            Path of the JSON index file written by ``save``. The files it lists are resolved
+            relative to the directory that contains it.
+
+        Returns
+        -------
+        SpotFindingResults:
+            Object containing loaded results
+
+        """
+        # Resolve the listed files against the index's directory instead of os.chdir(),
+        # which changes process-wide state and raises on an empty directory name.
+        dir_name = os.path.dirname(json_file)
+        with open(json_file) as f:
+            data = json.load(f)
+
+        # The log file is a JSON document; Log.decode() is handed only its 'log' entry,
+        # re-serialized to a string.
+        with open(os.path.join(dir_name, data["log"]["path"])) as f:
+            log = Log.decode(json.dumps(json.load(f)['log']))
+
+        rename_axes = {
+            'x': Coordinates.X.value,
+            'y': Coordinates.Y.value,
+            'z': Coordinates.Z.value
+        }
+        coords = {
+            rename_axes[coord]: xr.load_dataarray(os.path.join(dir_name, path))
+            for coord, path in data["physical_coord_ranges"].items()
+        }
+
+        spot_attributes_list = []
+        for key, path in data["spot_attrs"].items():
+            # Keys are two integers joined by "_", in AXES_ORDER order.
+            round_label, ch_label = (int(label) for label in key.split("_"))
+            index = {AXES_ORDER[0]: round_label, AXES_ORDER[1]: ch_label}
+            spots = SpotAttributes.load(os.path.join(dir_name, path))
+            spot_attributes_list.append((PerImageSliceSpotResults(spots, extras=None), index))
+
+        return cls(
+            imagestack_coords=coords,
+            log=log,
+            spot_attributes_list=spot_attributes_list
+        )
+
     @property
     def round_labels(self):
         """
diff --git a/starfish/core/types/test/test_saving_spots.py b/starfish/core/types/test/test_saving_spots.py
new file mode 100644
index 000000000..9e8c5e157
--- /dev/null
+++ b/starfish/core/types/test/test_saving_spots.py
@@ -0,0 +1,76 @@
+import os
+import tempfile
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from starfish.core.types import PerImageSliceSpotResults, SpotAttributes, SpotFindingResults
+from starfish.core.util.logging import Log
+from starfish.types import Axes, Coordinates, Features
+
+
+def dummy_spots() -> SpotFindingResults:
+    """Build a SpotFindingResults holding random spots for every (round, channel) pair."""
+    rounds = 4
+    channels = 3
+    spot_count = 100
+    img_dim = {'x': 2048, 'y': 2048, 'z': 29}
+
+    rename_axes = {
+        'x': Coordinates.X.value,
+        'y': Coordinates.Y.value,
+        'z': Coordinates.Z.value
+    }
+    # linspace rather than arange: arange(0, 1, size) treats size as the step and yields
+    # only [0], which would make the coordinate round-trip check below trivial.
+    coords = {
+        rename_axes[dim]: xr.DataArray(np.linspace(0, 1, size))
+        for dim, size in img_dim.items()
+    }
+
+    log = Log()
+
+    spot_attributes_list = []
+    for r in range(rounds):
+        for c in range(channels):
+            index = {Axes.ROUND: r, Axes.CH: c}
+            spots = SpotAttributes(pd.DataFrame(
+                np.random.randint(0, 100, size=(spot_count, 4)),
+                columns=[Axes.X.value,
+                         Axes.Y.value,
+                         Axes.ZPLANE.value,
+                         Features.SPOT_RADIUS]
+            ))
+            spot_attributes_list.append(
+                (PerImageSliceSpotResults(spots, extras=None), index)
+            )
+
+    return SpotFindingResults(
+        imagestack_coords=coords,
+        log=log,
+        spot_attributes_list=spot_attributes_list
+    )
+
+
+def test_saving_spots() -> None:
+    """Round-trip a SpotFindingResults through save() and load() and compare the parts."""
+    data = dummy_spots()
+
+    # TemporaryDirectory removes the serialized files when the block exits.
+    with tempfile.TemporaryDirectory() as tempdir:
+        # The trailing separator makes save() write into the directory.
+        data.save(os.path.join(tempdir, ""))
+
+        # load back into memory
+        data2 = SpotFindingResults.load(os.path.join(tempdir, 'SpotFindingResults.json'))
+
+        # ensure all items are equal
+        assert data.keys() == data2.keys()
+        assert data._log.encode() == data2._log.encode()
+        for ax in data.physical_coord_ranges:
+            np.testing.assert_equal(data.physical_coord_ranges[ax].to_numpy(),
+                                    data2.physical_coord_ranges[ax].to_numpy())
+        for k in data._results:
+            np.testing.assert_array_equal(data._results[k].spot_attrs.data,
+                                          data2._results[k].spot_attrs.data)