Skip to content

Commit a888dcb

Browse files
committed
change on_deserialize flow
1 parent 1e3f10c commit a888dcb

File tree

4 files changed

+128
-89
lines changed

4 files changed

+128
-89
lines changed

stixcore/io/product_processors/fits/processors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -823,10 +823,11 @@ def generate_primary_header(self, filename, product, *, version=0):
823823
if default[0] not in soop_key_names:
824824
soop_headers += tuple([default])
825825

826+
scet_range = product.scet_timerange
826827
time_headers = (
827828
# Name, Value, Comment
828-
("OBT_BEG", product.scet_timerange.start.as_float().value, "Start of acquisition time in OBT"),
829-
("OBT_END", product.scet_timerange.end.as_float().value, "End of acquisition time in OBT"),
829+
("OBT_BEG", scet_range.start.as_float().value, "Start of acquisition time in OBT"),
830+
("OBT_END", scet_range.end.as_float().value, "End of acquisition time in OBT"),
830831
("TIMESYS", "UTC", "System used for time keywords"),
831832
("LEVEL", "L1", "Processing level of the data"),
832833
("DATE-OBS", product.utc_timerange.start.fits, "Start of acquisition time in UTC"),

stixcore/products/level3/flarelist.py

Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from stixpy.coordinates.transforms import get_hpc_info
1212
from stixpy.net.client import STIXClient
1313
from stixpy.product import Product as STIXPYProduct
14-
from sunpy.coordinates import HeliographicStonyhurst, Helioprojective, SphericalScreen
14+
from sunpy.coordinates import HeliographicStonyhurst, Helioprojective
1515
from sunpy.map import make_fitswcs_header
1616
from sunpy.net import attrs as a
1717
from sunpy.time import TimeRange
@@ -77,7 +77,21 @@ def make_stix_fitswcs_header(data, flare_position, *, scale, exposure, rotation_
7777
return header
7878

7979

80-
class FlarePositionMixin:
80+
class _SerializeMixin:
81+
"""No-op chain terminator for on_serialize/on_deserialize.
82+
83+
Functional mixins inherit from this so super() calls always land safely
84+
instead of hitting object and raising AttributeError.
85+
"""
86+
87+
def on_serialize(self, data):
88+
pass
89+
90+
def on_deserialize(self, data, **kwargs):
91+
pass
92+
93+
94+
class FlarePositionMixin(_SerializeMixin):
8195
"""_summary_"""
8296

8397
@classmethod
@@ -101,8 +115,8 @@ def add_flare_position(
101115
data["cpd_path"] = Column(" " * 500, dtype=str, description="TDB")
102116
data["_position_status"] = Column(False, dtype=bool, description="TDB")
103117
data["_position_message"] = Column(" " * 500, dtype=str, description="TDB")
104-
tx_list, ty_list = [], []
105-
solo_x_list, solo_y_list, solo_z_list, peak_time_list = [], [], [], []
118+
# tx_list, ty_list = [], []
119+
solo_cartesian_list = []
106120

107121
to_remove = []
108122
pass_filter = 0
@@ -118,6 +132,7 @@ def add_flare_position(
118132
peak_time = row[peak_time_colname]
119133
start_time = row[start_time_colname]
120134
end_time = row[end_time_colname]
135+
logger.info(f"Processing flare {i}/{len(data)} at time {start_time} : {end_time} (peak at {peak_time})")
121136
if filter_function(row): # and i < 60:
122137
pass_filter += 1
123138
day = peak_time.to_datetime().date()
@@ -137,6 +152,8 @@ def add_flare_position(
137152
logger.warning(f"No ephemeris data found for flare at time {start_time} : {end_time}")
138153
data[i]["_position_message"] = "no ephemeris data found"
139154
no_ephemeris += 1
155+
solo_cartesian_list.append((np.nan * u.km, np.nan * u.km, np.nan * u.km))
156+
140157
continue
141158
data[i]["anc_ephemeris_path"] = anc_res["path"][0]
142159

@@ -153,6 +170,8 @@ def add_flare_position(
153170
logger.warning(f"No CPD data found for flare at time {start_time} : {end_time}")
154171
data[i]["_position_message"] = "no CPD data found"
155172
no_cpd += 1
173+
solo_cartesian_list.append((np.nan * u.km, np.nan * u.km, np.nan * u.km))
174+
156175
continue
157176
if len(cpd_res) > 1:
158177
logger.debug(f"Many CPD data found for flare at time {start_time} : {end_time}")
@@ -234,55 +253,48 @@ def add_flare_position(
234253

235254
roll, solo_xyz, pointing = get_hpc_info(start_time, end_time)
236255
solo = HeliographicStonyhurst(*solo_xyz, obstime=peak_time, representation_type="cartesian")
237-
with SphericalScreen(solo, only_off_disk=True):
238-
center_hpc = coord.transform_to(Helioprojective(observer=solo))
239-
tx_list.append(center_hpc.Tx)
240-
ty_list.append(center_hpc.Ty)
241-
solo_x_list.append(solo.cartesian.x)
242-
solo_y_list.append(solo.cartesian.y)
243-
solo_z_list.append(solo.cartesian.z)
244-
peak_time_list.append(peak_time)
256+
# with SphericalScreen(solo, only_off_disk=True):
257+
# center_hpc = coord.transform_to(Helioprojective(observer=solo))
258+
# tx_list.append(center_hpc.Tx)
259+
# ty_list.append(center_hpc.Ty)
260+
solo_cartesian_list.append((solo.cartesian.x, solo.cartesian.y, solo.cartesian.z))
245261

246262
data[i]["_position_status"] = True
247263
data[i]["_position_message"] = "OK"
248264
except Exception as e:
249265
data[i]["_position_status"] = False
250266
data[i]["_position_message"] = f"Error: {type(e)}"
251267
logger.warning(f"Error calculating flare position for flare at time {start_time} : {end_time}: {e}")
252-
tx_list.append(np.nan * u.arcsec)
253-
ty_list.append(np.nan * u.arcsec)
254-
solo_x_list.append(np.nan * u.km)
255-
solo_y_list.append(np.nan * u.km)
256-
solo_z_list.append(np.nan * u.km)
257-
peak_time_list.append(peak_time)
268+
# tx_list.append(np.nan * u.arcsec)
269+
# ty_list.append(np.nan * u.arcsec)
270+
solo_cartesian_list.append((np.nan * u.km, np.nan * u.km, np.nan * u.km))
271+
258272
else:
259273
to_remove.append(i)
260-
tx_list.append(np.nan * u.arcsec)
261-
ty_list.append(np.nan * u.arcsec)
262-
solo_x_list.append(np.nan * u.km)
263-
solo_y_list.append(np.nan * u.km)
264-
solo_z_list.append(np.nan * u.km)
265-
peak_time_list.append(peak_time)
274+
# tx_list.append(np.nan * u.arcsec)
275+
# ty_list.append(np.nan * u.arcsec)
276+
solo_cartesian_list.append((np.nan * u.km, np.nan * u.km, np.nan * u.km))
266277
data[i]["_position_status"] = False
267278
data[i]["_position_message"] = "flare did not pass the filter function"
268279

269-
solo_times = Time(peak_time_list)
280+
solo_times = Time(data[peak_time_colname])
281+
solo_x, solo_y, solo_z = zip(*solo_cartesian_list)
270282
hgs_coords = SkyCoord(
271-
u.Quantity(solo_x_list),
272-
u.Quantity(solo_y_list),
273-
u.Quantity(solo_z_list),
283+
u.Quantity(solo_x),
284+
u.Quantity(solo_y),
285+
u.Quantity(solo_z),
274286
frame=HeliographicStonyhurst(obstime=solo_times),
275287
representation_type="cartesian",
276288
)
277289

278-
hp_coords = SkyCoord(
279-
u.Quantity(tx_list), u.Quantity(ty_list), frame=Helioprojective(obstime=solo_times, observer=hgs_coords)
280-
)
290+
# hp_coords = SkyCoord(
291+
# u.Quantity(tx_list), u.Quantity(ty_list), frame=Helioprojective(obstime=solo_times, observer=hgs_coords)
292+
# )
281293

282294
data["location_hgs"] = hgs_coords
283295
# description="Flare location in Heliographic Stonyhurst coordinates"
284296

285-
data["location_hp"] = hp_coords
297+
# data["location_hp"] = hp_coords
286298
# description="Flare location in Helioprojective coordinates"
287299

288300
if not keep_all_flares:
@@ -297,35 +309,31 @@ def add_flare_position(
297309
)
298310

299311
def on_serialize(self, data):
300-
for col_name in ("location_hgs", "location_hp"):
301-
if col_name in data.colnames:
302-
icrs = data[col_name].icrs
303-
icrs_coord = SkyCoord(icrs.ra, icrs.dec, icrs.distance, frame="icrs")
304-
col_idx = data.colnames.index(col_name)
305-
data.remove_column(col_name)
306-
data.add_column(icrs_coord, name=col_name, index=col_idx)
307-
s = super()
308-
if hasattr(s, "on_serialize"):
309-
s.on_serialize(data)
310-
311-
def on_deserialize(self, data, *, peak_time_colname=None):
312+
logger.info("FlarePositionMixin on_serialize called, transforming location columns to ICRS for serialization")
313+
314+
if "location_hgs" in data.colnames:
315+
icrs = data["location_hgs"].icrs
316+
icrs_coord = SkyCoord(icrs.ra, icrs.dec, icrs.distance, frame="icrs")
317+
col_idx = data.colnames.index("location_hgs")
318+
data.remove_column("location_hgs")
319+
data.add_column(icrs_coord, name="location_icrs", index=col_idx)
320+
super().on_serialize(data)
321+
322+
def on_deserialize(self, data, *, peak_time_colname=None, **kwargs):
323+
logger.info(
324+
"FlarePositionMixin on_deserialize called, transforming location columns back to heliographic coordinates"
325+
)
312326
peak_col = peak_time_colname or self.peak_time_colname
313327
if peak_col not in data.colnames:
314328
logger.warning(f"on_deserialize: column '{peak_col}' not found, skipping location transform")
315329
else:
316330
obstime = Time(data[peak_col])
317-
if "location_hgs" in data.colnames:
318-
data["location_hgs"] = data["location_hgs"].transform_to(HeliographicStonyhurst(obstime=obstime))
319-
if "location_hp" in data.colnames:
320-
data["location_hp"] = data["location_hp"].transform_to(
321-
Helioprojective(obstime=obstime, observer=data["location_hgs"])
322-
)
323-
s = super()
324-
if hasattr(s, "on_deserialize"):
325-
s.on_deserialize(data)
331+
if "location_icrs" in data.colnames:
332+
data["location_hgs"] = data["location_icrs"].transform_to(HeliographicStonyhurst(obstime=obstime))
333+
super().on_deserialize(data, **kwargs)
326334

327335

328-
class FlareSOOPMixin:
336+
class FlareSOOPMixin(_SerializeMixin):
329337
"""_summary_"""
330338

331339
@classmethod
@@ -352,6 +360,14 @@ def add_soop(
352360
data["soop_id"] = Column(soop_id, dtype=str, description="SOOP ID")
353361
data["soop_type"] = Column(soop_type, dtype=str, description="name of the SOOP campaign")
354362

363+
# def on_serialize(self, data):
364+
# logger.info("FlareSOOPMixin on_serialize called, but no special handling implemented for SOOP data")
365+
# super().on_serialize(data)
366+
367+
# def on_deserialize(self, data, **kwargs):
368+
# logger.info("FlareSOOPMixin on_deserialize called, but no special handling implemented for SOOP data")
369+
# super().on_deserialize(data, **kwargs)
370+
355371

356372
class FlarePeakPreviewMixin:
357373
"""Mixin class to add peak preview images to flare list products.

stixcore/products/product.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -613,16 +613,21 @@ def max_exposure(self):
613613
return 0.0
614614

615615
def on_serialize(self, data):
616-
"""Hook called before writing data to FITS. Mixins override and chain via super()."""
617-
s = super()
618-
if hasattr(s, "on_serialize"):
619-
s.on_serialize(data)
616+
"""Hook called before writing data to FITS. Mixins override and chain via super().
620617
621-
def on_deserialize(self, data):
618+
Uses getattr so plain products without functional mixins are safe —
619+
the functional mixins (FlarePositionMixin etc.) appear after GenericProduct
620+
in the MRO, so a bare ``pass`` statement here would stop the chain before reaching them.
621+
"""
622+
serialize = getattr(super(), "on_serialize", None)
623+
if serialize is not None:
624+
serialize(data)
625+
626+
def on_deserialize(self, data, **kwargs):
622627
"""Hook called after reading data from FITS. Mixins override and chain via super()."""
623-
s = super()
624-
if hasattr(s, "on_deserialize"):
625-
s.on_deserialize(data)
628+
deserialize = getattr(super(), "on_deserialize", None)
629+
if deserialize is not None:
630+
deserialize(data, **kwargs)
626631

627632
def find_parent_products(self, root):
628633
"""

stixcore/products/tests/test_flarelist.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import numpy as np
44
import pytest
5-
from sunpy.coordinates import HeliographicStonyhurst, Helioprojective
5+
from sunpy.coordinates import HeliographicStonyhurst
66

77
import astropy.units as u
88
from astropy.coordinates import SkyCoord
99
from astropy.io import fits
1010
from astropy.table import QTable
11+
from astropy.tests.helper import assert_quantity_allclose
1112
from astropy.time import Time
1213

1314
from stixcore.io.product_processors.fits.processors import FitsL3Processor
@@ -21,32 +22,32 @@
2122
def flare_data():
2223
peak_times = Time("2022-01-01T12:00:00") + np.arange(N) * 600 * u.s
2324

25+
lon = np.linspace(0, 30, N)
26+
lat = np.linspace(-5, 5, N)
27+
# set lon and lat to NaN on the same rows so the entire SkyCoord row is invalid
28+
lon[2] = lat[2] = np.nan
29+
lon[7] = lat[7] = np.nan
30+
2431
hgs_coords = SkyCoord(
25-
lon=np.linspace(0, 30, N) * u.deg,
26-
lat=np.linspace(-5, 5, N) * u.deg,
32+
lon=lon * u.deg,
33+
lat=lat * u.deg,
2734
radius=np.ones(N) * 1.0 * u.AU,
2835
frame=HeliographicStonyhurst(obstime=peak_times),
2936
)
3037

31-
hp_coords = SkyCoord(
32-
Tx=np.linspace(-300, 300, N) * u.arcsec,
33-
Ty=np.linspace(-200, 200, N) * u.arcsec,
34-
frame=Helioprojective(obstime=peak_times, observer=hgs_coords),
35-
)
36-
3738
data = QTable()
3839
data["peak_UTC"] = peak_times
3940
data["start_UTC"] = peak_times - 60 * u.s
4041
data["end_UTC"] = peak_times + 60 * u.s
4142
data["duration"] = np.ones(N) * 120 * u.s
4243
data["lc_peak"] = np.ones((N, 5)) * u.ct / u.s
4344
data["location_hgs"] = hgs_coords
44-
data["location_hp"] = hp_coords
4545

4646
return data
4747

4848

49-
def test_flarelist_sdcloc_location_roundtrip(flare_data, tmp_path):
49+
@pytest.fixture
50+
def written_fits(flare_data, tmp_path):
5051
prod = FlarelistSDCLoc(
5152
data=flare_data,
5253
month=date(2022, 1, 1),
@@ -61,26 +62,42 @@ def test_flarelist_sdcloc_location_roundtrip(flare_data, tmp_path):
6162
header["SSID"] = 3
6263
header["DATE-BEG"] = "2022-01-01T00:00:00"
6364
prod.fits_header = header
64-
65-
# energy/additional_header_keywords are not set for freshly created products
6665
prod.energy = None
6766
prod._additional_header_keywords = []
6867

68+
writer = FitsL3Processor(tmp_path)
69+
written = writer.write_fits(prod)
70+
assert len(written) == 1
71+
72+
return prod, written[0]
73+
74+
75+
def test_flarelist_sdcloc_location_roundtrip(written_fits):
76+
prod, fits_path = written_fits
6977
orig_hgs_lon = prod.data["location_hgs"].lon.copy()
7078
orig_hgs_lat = prod.data["location_hgs"].lat.copy()
71-
orig_hp_tx = prod.data["location_hp"].Tx.copy()
72-
orig_hp_ty = prod.data["location_hp"].Ty.copy()
73-
74-
# write via FitsL3Processor — calls on_serialize internally, prod.data unchanged
75-
writer = FitsL3Processor(tmp_path)
76-
written_file_name = writer.write_fits(prod)
77-
assert len(written_file_name) == 1
7879

7980
# read back via Product factory — calls on_deserialize internally
80-
recovered = Product(written_file_name[0])
81+
recovered = Product(fits_path)
8182

8283
assert isinstance(recovered, FlarelistSDCLoc)
83-
assert u.allclose(recovered.data["location_hgs"].lon, orig_hgs_lon, atol=1e-6 * u.deg)
84-
assert u.allclose(recovered.data["location_hgs"].lat, orig_hgs_lat, atol=1e-6 * u.deg)
85-
assert u.allclose(recovered.data["location_hp"].Tx, orig_hp_tx, atol=1e-3 * u.arcsec)
86-
assert u.allclose(recovered.data["location_hp"].Ty, orig_hp_ty, atol=1e-3 * u.arcsec)
84+
assert_quantity_allclose(recovered.data["location_hgs"].lon, orig_hgs_lon, atol=1e-6 * u.deg, equal_nan=True)
85+
assert_quantity_allclose(recovered.data["location_hgs"].lat, orig_hgs_lat, atol=1e-6 * u.deg, equal_nan=True)
86+
87+
88+
def test_flarelist_sdcloc_fits_stores_icrs(written_fits):
89+
prod, fits_path = written_fits
90+
orig_hgs_lon = prod.data["location_hgs"].lon.copy()
91+
orig_hgs_lat = prod.data["location_hgs"].lat.copy()
92+
93+
# read the DATA extension directly — no on_deserialize, raw FITS content
94+
raw = QTable.read(fits_path, hdu="DATA", astropy_native=True)
95+
96+
assert "location_hgs" not in raw.colnames, "HGS column should not be stored in FITS"
97+
assert "location_icrs" in raw.colnames, "ICRS column should be present in FITS"
98+
99+
# manually transform ICRS back to HGS and compare with original
100+
obstime = Time(raw["peak_UTC"])
101+
hgs = raw["location_icrs"].transform_to(HeliographicStonyhurst(obstime=obstime))
102+
assert_quantity_allclose(hgs.lon, orig_hgs_lon, atol=1e-6 * u.deg, equal_nan=True)
103+
assert_quantity_allclose(hgs.lat, orig_hgs_lat, atol=1e-6 * u.deg, equal_nan=True)

0 commit comments

Comments
 (0)