Skip to content

Commit f99287f

Browse files
authored
Merge pull request #18 from lincc-frameworks/issue/15/todos
Address various todos from comments; add unit tests for skymap writer
2 parents 7402393 + 0d57b12 commit f99287f

6 files changed

Lines changed: 250 additions & 67 deletions

File tree

src/skymap_convert/skymap_readers.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class ConvertedSkymapReader:
3131
file_path : Path
3232
Path to the skymap directory.
3333
safe_loading : bool
34-
Whether to verify polygon non-degeneracy when loading vertices.
34+
Whether to verify polygon non-degeneracy when loading vertices. Additionally, if true,
35+
checks that the metadata matches the array shapes.
3536
"""
3637

3738
def __init__(self, file_path: str | Path = None, safe_loading: bool = False, preset: str = None):
@@ -47,6 +48,15 @@ def __init__(self, file_path: str | Path = None, safe_loading: bool = False, pre
4748
preset : str, optional
4849
Name of a built-in skymap preset to load. If specified, file_path is ignored.
4950
Available presets can be listed with skymap_convert.presets.list_available_presets().
51+
52+
Raises
53+
------
54+
ValueError
55+
If neither file_path nor preset is provided, or if the specified preset does not exist.
56+
FileNotFoundError
57+
If the specified file_path does not exist or is not a directory.
58+
AssertionError
59+
If safe_loading is True and the metadata does not match the array shapes.
5060
"""
5161
# Use preset if provided, otherwise check for file_path
5262
if preset is not None:
@@ -87,7 +97,12 @@ def __init__(self, file_path: str | Path = None, safe_loading: bool = False, pre
8797
self.tracts = np.load(self.tracts_path, mmap_mode="r")
8898
self.patches = np.load(self.patches_path, mmap_mode="r")
8999

90-
# TODO could be nice to check if the metadata matches the arrays here.
100+
# Check if the metadata matches the arrays
101+
if self.safe_loading:
102+
assert self.n_tracts == self.tracts.shape[0], "Metadata n_tracts does not match array shape"
103+
assert (
104+
self.n_patches_per_tract == self.patches.shape[1]
105+
), "Metadata n_patches_per_tract does not match array shape"
91106

92107
def _decompress_patches_gz(self) -> Path:
93108
"""Decompress patches.npy.gz to a temporary file if not already done.

src/skymap_convert/skymap_writers.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def write(self, skymap, output_path: str | Path, skymap_name: str = "converted_s
4646
4747
Notes
4848
-----
49-
The method assumes 100 patches per tract, which is standard for LSST skymaps.
50-
All polygon vertices are stored as [RA, Dec] coordinates in degrees.
49+
- The method assumes all tracts have the same number of patches, which is standard for LSST
50+
skymaps.
51+
- All polygon vertices are stored as [RA, Dec] coordinates in degrees.
5152
5253
Examples
5354
--------
@@ -59,8 +60,24 @@ def write(self, skymap, output_path: str | Path, skymap_name: str = "converted_s
5960
output_path = Path(output_path)
6061
self._ensure_output_directory(output_path)
6162

62-
n_tracts = len(skymap) # TODO - is this accounting for poles?
63-
n_patches = 100 # fixed per tract # TODO - would be nice to make this dynamic, though we expect 100
63+
n_tracts = skymap._numTracts
64+
65+
# Get number of patches from the first tract
66+
first_tract = skymap[0]
67+
n_patches_0 = (
68+
first_tract._tractBuilder._numPatches.x * first_tract._tractBuilder._numPatches.y
69+
) # Number of patches per tract, typically 100.
70+
# Get the number of patches from the second tract
71+
second_tract = skymap[1]
72+
n_patches_1 = second_tract._tractBuilder._numPatches.x * second_tract._tractBuilder._numPatches.y
73+
# Check that they are the same
74+
if n_patches_0 != n_patches_1:
75+
raise ValueError(
76+
f"Number of patches per tract mismatch: {n_patches_0} != {n_patches_1}. "
77+
"Ensure all tracts have the same number of patches."
78+
)
79+
n_patches = n_patches_0 # Use the first tract's patch count
80+
6481
tract_array = np.zeros((n_tracts, 4, 2), dtype=np.float64)
6582
patch_array = np.zeros((n_tracts, n_patches, 4, 2), dtype=np.float64)
6683

src/skymap_convert/utils.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -118,66 +118,6 @@ def radians_to_degrees(radians):
118118
return radians * (180.0 / math.pi)
119119

120120

121-
class IterateTractAndRing:
122-
"""An iterator to traverse through tract IDs and ring IDs in a skymap.
123-
124-
This iterator yields tuples of (tract_index, ring_index) for each tract in the skymap.
125-
The first tract is the south pole (ring -1), and the last tract is the north pole.
126-
127-
Parameters
128-
----------
129-
ring_nums : list of int
130-
A list where each element represents the number of tracts in each ring.
131-
For example, [5, 10, 15] represents 5 tracts in ring 0, 10 in ring 1,
132-
and 15 in ring 2.
133-
add_poles : bool, optional
134-
If True, include the south pole (tract 0) and north pole (last tract)
135-
in the iteration. If False, only iterate through the rings (default: True).
136-
137-
Examples
138-
--------
139-
>>> iterator = IterateTractAndRing([5, 10, 15], add_poles=True)
140-
>>> for tract_id, ring_id in iterator:
141-
... print(f"Tract {tract_id} is in ring {ring_id}")
142-
Tract 0 is in ring -1
143-
Tract 1 is in ring 0
144-
...
145-
146-
TODO : this is more or less deprecated, I belive. Consider removing.
147-
"""
148-
149-
def __init__(self, ring_nums, add_poles=True):
150-
self.ring_nums = ring_nums
151-
if add_poles:
152-
self.total_tracts = sum(ring_nums) + 2
153-
self.current_tract = 0
154-
self.current_ring = -1
155-
else:
156-
self.total_tracts = sum(ring_nums)
157-
self.current_tract = 1
158-
self.current_ring = 0
159-
160-
def __iter__(self):
161-
return self
162-
163-
def __next__(self):
164-
# End iteration if we have processed all tracts.
165-
if self.current_tract >= self.total_tracts:
166-
raise StopIteration
167-
tract_and_ring = (self.current_tract, self.current_ring)
168-
169-
# Increase tract.
170-
self.current_tract += 1
171-
172-
# Check if we need to move to the next ring.clear
173-
if self.current_ring == -1:
174-
self.current_ring += 1
175-
elif self.current_tract > sum(self.ring_nums[: self.current_ring + 1]):
176-
self.current_ring += 1
177-
178-
return tract_and_ring
179-
180-
181121
def get_poly_from_tract_id(skymap, tract_id, inner=False):
182122
"""Get the ConvexPolygon for a tract by its ID.
183123

tests/skymap_convert/conftest.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,15 @@ def converted_skymap_reader():
5252
Note
5353
----
5454
This uses the pre-converted LSST skymap for performance reasons.
55+
The reader will be automatically cleaned up at the end of the test session.
5556
"""
57+
reader = ConvertedSkymapReader(preset="lsst_skymap")
5658

57-
return ConvertedSkymapReader(preset="lsst_skymap")
59+
# Yield the reader for use in tests
60+
yield reader
61+
62+
# Cleanup after all tests are done
63+
reader.cleanup()
5864

5965

6066
@pytest.fixture(scope="session")
@@ -74,3 +80,59 @@ def lsst_skymap():
7480
"""
7581
pytest.importorskip("lsst.skymap")
7682
return load_pickle_skymap(RAW_SKYMAP_DIR / "skyMap_lsst_cells_v1_skymaps.pickle")
83+
84+
85+
@pytest.fixture(scope="session")
86+
def written_skymap_data(lsst_skymap, tmp_path_factory, request):
87+
"""Session-scoped fixture that writes skymap data once and provides paths for testing.
88+
89+
Parameters
90+
----------
91+
lsst_skymap : lsst.skymap.SkyMap
92+
The original LSST skymap object to write.
93+
tmp_path_factory : pytest.TempPathFactory
94+
Pytest factory for creating temporary directories.
95+
request : pytest.FixtureRequest
96+
Pytest request object to access command line options.
97+
98+
Returns
99+
-------
100+
dict
101+
Dictionary containing:
102+
- 'output_dir': Path to the directory containing written skymap files
103+
- 'skymap_name': Name used for the skymap
104+
- 'reader': ConvertedSkymapReader instance for the written data
105+
106+
Notes
107+
-----
108+
This fixture only runs if --longrun is specified, otherwise it skips.
109+
The expensive skymap writing operation happens only once per test session.
110+
The reader will be automatically cleaned up at the end of the test session.
111+
"""
112+
# Skip if not running longrun tests
113+
if not request.config.option.longrun:
114+
pytest.skip("Skipping written_skymap_data fixture - requires --longrun")
115+
116+
pytest.importorskip("lsst.skymap")
117+
pytest.importorskip("lsst.sphgeom")
118+
119+
from skymap_convert.skymap_writers import ConvertedSkymapWriter
120+
121+
# Create session-scoped temporary directory
122+
output_dir = tmp_path_factory.mktemp("session_skymap_data")
123+
skymap_name = "test_session_skymap"
124+
125+
# Write the skymap once
126+
writer = ConvertedSkymapWriter()
127+
writer.write(lsst_skymap, output_dir, skymap_name)
128+
129+
# Create reader for the written data
130+
reader = ConvertedSkymapReader(output_dir)
131+
132+
data = {"output_dir": output_dir, "skymap_name": skymap_name, "reader": reader}
133+
134+
# Yield the data for use in tests
135+
yield data
136+
137+
# Cleanup after all tests are done
138+
reader.cleanup()
File renamed without changes.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import pytest
2+
import yaml
3+
from skymap_convert.utils import (
4+
get_quad_from_patch_id,
5+
get_quad_from_tract_id,
6+
quads_are_equiv,
7+
)
8+
from tqdm import tqdm
9+
10+
TRACT_SAMPLES = [1, 250, 1899]
11+
PATCH_SAMPLES = [0, 42, 99]
12+
13+
14+
@pytest.mark.longrun
15+
def test_writer_creates_expected_files(lsst_skymap, written_skymap_data):
16+
"""Test that ConvertedSkymapWriter creates all expected output files.
17+
18+
Parameters
19+
----------
20+
lsst_skymap : lsst.skymap.SkyMap
21+
The original LSST skymap object to write.
22+
written_skymap_data : dict
23+
Fixture providing written skymap data with output_dir, skymap_name, and reader.
24+
25+
Notes
26+
-----
27+
Verifies that:
28+
- All expected files (metadata.yaml, tracts.npy, patches.npy.gz) are created
29+
- No temporary files are left behind
30+
- Files have non-zero size
31+
"""
32+
pytest.importorskip("lsst.skymap")
33+
pytest.importorskip("lsst.sphgeom")
34+
35+
output_dir = written_skymap_data["output_dir"]
36+
skymap_name = written_skymap_data["skymap_name"]
37+
38+
# Check expected files exist
39+
assert (output_dir / "metadata.yaml").exists()
40+
assert (output_dir / "tracts.npy").exists()
41+
assert (output_dir / "patches.npy.gz").exists()
42+
43+
# Check temporary files are cleaned up
44+
assert not (output_dir / "patches.npy").exists()
45+
46+
# Check files have content
47+
assert (output_dir / "metadata.yaml").stat().st_size > 0
48+
assert (output_dir / "tracts.npy").stat().st_size > 0
49+
assert (output_dir / "patches.npy.gz").stat().st_size > 0
50+
51+
# Load and check metadata
52+
with open(output_dir / "metadata.yaml", "r") as f:
53+
metadata = yaml.safe_load(f)
54+
55+
assert metadata["name"] == skymap_name
56+
assert metadata["n_tracts"] == len(lsst_skymap)
57+
assert metadata["n_patches_per_tract"] == 100
58+
assert metadata["format_version"] == 1
59+
assert "generated" in metadata
60+
assert metadata["generated"].endswith("Z") # ISO format with UTC
61+
62+
63+
@pytest.mark.longrun
64+
def test_writer_sample_tracts(lsst_skymap, written_skymap_data):
65+
"""Test complete round-trip: write with ConvertedSkymapWriter, read with ConvertedSkymapReader.
66+
67+
Parameters
68+
----------
69+
lsst_skymap : lsst.skymap.SkyMap
70+
The original LSST skymap object.
71+
written_skymap_data : dict
72+
Fixture providing written skymap data with output_dir, skymap_name, and reader.
73+
74+
Notes
75+
-----
76+
Verifies that data written by ConvertedSkymapWriter can be correctly
77+
read by ConvertedSkymapReader and produces identical results.
78+
"""
79+
pytest.importorskip("lsst.skymap")
80+
pytest.importorskip("lsst.sphgeom")
81+
82+
reader = written_skymap_data["reader"]
83+
skymap_name = written_skymap_data["skymap_name"]
84+
85+
# Verify metadata matches
86+
assert reader.metadata["name"] == skymap_name
87+
assert reader.n_tracts == len(lsst_skymap)
88+
assert reader.n_patches_per_tract == 100
89+
90+
# Verify a few sample geometries match original
91+
for tract_id in TRACT_SAMPLES:
92+
original_tract = get_quad_from_tract_id(lsst_skymap, tract_id, inner=True)
93+
round_trip_tract = reader.get_tract_vertices(tract_id)
94+
95+
assert quads_are_equiv(original_tract, round_trip_tract), f"Round-trip failed for tract {tract_id}"
96+
97+
for patch_id in PATCH_SAMPLES:
98+
original_patch = get_quad_from_patch_id(lsst_skymap, tract_id, patch_id)
99+
round_trip_patch = reader.get_patch_vertices(tract_id, patch_id)
100+
101+
assert quads_are_equiv(
102+
original_patch, round_trip_patch
103+
), f"Round-trip failed for patch {patch_id} in tract {tract_id}"
104+
105+
106+
@pytest.mark.longrun
107+
def test_writer_full_skymap_integrity(lsst_skymap, written_skymap_data):
108+
"""Comprehensive test that writer preserves all tract and patch geometries.
109+
110+
Parameters
111+
----------
112+
lsst_skymap : lsst.skymap.SkyMap
113+
The original LSST skymap object.
114+
written_skymap_data : dict
115+
Fixture providing written skymap data with output_dir, skymap_name, and reader.
116+
117+
Notes
118+
-----
119+
This is a comprehensive test that verifies every tract and patch geometry
120+
is preserved correctly using the session-scoped written skymap data.
121+
"""
122+
if True:
123+
pytest.skip(
124+
"Skipping full integrity test as it takes an exceptionally long time to run. "
125+
"Recommend running periodically, especially after major changes to the writer. "
126+
"To run, manually remove the branching logic in the code of the test."
127+
)
128+
else:
129+
pytest.importorskip("lsst.skymap")
130+
pytest.importorskip("lsst.sphgeom")
131+
132+
reader = written_skymap_data["reader"]
133+
134+
tract_ids = range(len(lsst_skymap))
135+
for tract_id in tqdm(tract_ids, desc="Verifying written tracts", leave=False):
136+
# Verify tract geometry
137+
truth_quad = get_quad_from_tract_id(lsst_skymap, tract_id, inner=True)
138+
written_quad = reader.get_tract_vertices(tract_id)
139+
140+
assert quads_are_equiv(truth_quad, written_quad), f"Tract {tract_id} geometry not preserved"
141+
142+
# Verify all patch geometries
143+
for patch_id in range(100):
144+
truth_patch = get_quad_from_patch_id(lsst_skymap, tract_id, patch_id)
145+
written_patch = reader.get_patch_vertices(tract_id, patch_id)
146+
147+
assert quads_are_equiv(
148+
truth_patch, written_patch
149+
), f"Patch {patch_id} in tract {tract_id} geometry not preserved"

0 commit comments

Comments
 (0)