Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions sacc/covariance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from astropy.table import Table
from astropy.table import Table, Column
import scipy.linalg
import numpy as np
import warnings
Expand Down Expand Up @@ -218,9 +218,9 @@ def to_table(self):
table: astropy.table.Table instance
Table that can be used to reconstruct the object.
"""
col_names = [f'col_{i}' for i in range(self.size)]
cols = [self.covmat[i] for i in range(self.size)]
table = Table(data=cols, names=col_names)
# Store as a single vector column ('row') to avoid FITS TFIELDS>999
# Each table row is one row of the covariance matrix
table = Table([Column(name='row', data=self.covmat)])
table.meta['SIZE'] = self.size
return table

Expand All @@ -239,7 +239,11 @@ def from_table(cls, table):
Loaded covariance object
"""
size = table.meta['SIZE']
covmat = np.array([table[f'col_{i}'] for i in range(size)])
# Support both legacy many-column format and new single-column format
if 'row' in table.colnames:
covmat = np.array(list(table['row']))
else:
covmat = np.array([table[f'col_{i}'] for i in range(size)])
return cls(covmat)


Expand Down Expand Up @@ -375,8 +379,12 @@ def from_tables(cls, tables):
for i in range(nblock):
table = tables[f'block_{i}']
block_size = table.meta['SACCBSZE']
cols = [table[f'block_col_{i}'] for i in range(block_size)]
blocks.append(np.array(cols))
if 'block_row' in table.colnames:
block = np.array(list(table['block_row']))
else:
cols = [table[f'block_col_{j}'] for j in range(block_size)]
block = np.array(cols)
blocks.append(block)
return cls(blocks)

def to_tables(self):
Expand All @@ -396,9 +404,8 @@ def to_tables(self):
nblock = len(self.blocks)
for j, block in enumerate(self.blocks):
b = len(block)
col_names = [f'block_col_{i}' for i in range(b)]
cols = [block[i] for i in range(b)]
table = Table(data=cols, names=col_names)
# Use single vector column to minimize TFIELDS
table = Table([Column(name='block_row', data=block)])
table.meta['SIZE'] = self.size
table.meta['SACCBIDX'] = j
table.meta['SACCBCNT'] = nblock
Expand Down
38 changes: 26 additions & 12 deletions sacc/sacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from astropy.io import fits
from astropy.table import Table

# Module-level constants for Sacc file format versions
SACCFVER = 2 # Current FITS version
SACCHDF5VER = 1 # Current HDF5 version
import numpy as np

from .tracers import BaseTracer
Expand Down Expand Up @@ -940,10 +944,10 @@ def save_fits(self, filename, overwrite=False):

# Create the actual fits object
primary_header = fits.Header()
primary_header['SACCFVER'] = SACCFVER
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=fits.verify.VerifyWarning)
hdus = [fits.PrimaryHDU(header=primary_header)] + \
[fits.table_to_hdu(table) for table in tables]
hdus = [fits.PrimaryHDU(header=primary_header)] + [fits.table_to_hdu(table) for table in tables]
hdu_list = fits.HDUList(hdus)
io.astropy_buffered_fits_write(filename, hdu_list)

Expand All @@ -962,36 +966,34 @@ def load_fits(cls, filename):
"""
cov = None
metadata = None
fitsver = None

with fits.open(filename, mode="readonly") as f:
tables = []
for hdu in f:
for idx, hdu in enumerate(f):
if hdu.name.lower() == 'primary':
# The primary table is not a data table,
# but in older files it was used to store metadata
header = hdu.header
fitsver = header.get('SACCFVER', None)
if fitsver is None:
fitsver = 1
if fitsver > SACCFVER:
raise RuntimeError(f"Unsupported SACC FITS version: {fitsver}")
if "NMETA" in header:
metadata = {}
# Older format metadata is kept in the primary
# header, with keys KEY0, VAL0, etc.
n_meta = header['NMETA']
for i in range(n_meta):
k = header[f'KEY{i}']
v = header[f'VAL{i}']
metadata[k] = v
elif hdu.name.lower() == 'covariance':
# Legacy covariance - HDU will just be called covariance
# instead of the full name given by BaseIO.
# Note that this will also allow us to use multiple
# covariances in future.
cov = BaseCovariance.from_hdu(hdu)
else:
tables.append(Table.read(hdu))

# add the metadata table, if we are in the legacy format
if metadata is not None:
tables.append(io.metadata_to_table(metadata))

# Pass version to from_tables if needed (future-proofing)
return cls.from_tables(tables, cov=cov)

def save_hdf5(self, filename, overwrite=False, compression='gzip', compression_opts=4):
Expand Down Expand Up @@ -1027,6 +1029,8 @@ def save_hdf5(self, filename, overwrite=False, compression='gzip', compression_o
table.meta['EXTNAME'] = extname

with h5py.File(filename, 'w') as f:
# Write version dataset
f.create_dataset('sacc_hdf5_version', data=np.array([SACCHDF5VER], dtype='i4'))
used_names = {}
for table in tables:
# Build a meaningful dataset name
Expand Down Expand Up @@ -1087,9 +1091,19 @@ def load_hdf5(cls, filename):
"""
import h5py
recovered_tables = []
hdf5ver = None
with h5py.File(filename, 'r') as f:
# Check version
if 'sacc_hdf5_version' in f:
hdf5ver = int(np.array(f['sacc_hdf5_version'])[0])
else:
hdf5ver = 1
if hdf5ver > SACCHDF5VER:
raise RuntimeError(f"Unsupported SACC HDF5 version: {hdf5ver}")
# Read all datasets (not groups) in the order they appear
for key in f.keys():
if key == 'sacc_hdf5_version':
continue
item = f[key]
if isinstance(item, h5py.Dataset):
table = Table.read(f, path=key)
Expand Down
36 changes: 36 additions & 0 deletions test/test_bug_132.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import tempfile

import numpy as np

import sacc


def test_bug_132():
    """Round-trip a Sacc object with a 1000x1000 covariance through FITS and HDF5.

    Regression test for issue #132: a covariance stored one column per
    matrix row exceeds the FITS TFIELDS limit (999 columns) once the data
    vector grows past 999 entries.  Saving and reloading must reproduce
    the original object exactly in both formats.
    """
    n_ell = 100
    ndata = 1000

    s = sacc.Sacc()
    s.add_tracer(
        "Map",
        "tracer",
        quantity="galaxy_density",
        spin=0,
        ell=np.arange(n_ell),
        beam=np.ones(n_ell),
    )
    s.add_ell_cl("cl_00", "tracer", "tracer", np.arange(ndata), np.ones(ndata))
    s.add_covariance(np.eye(ndata))

    # Write and re-read both formats inside a throwaway directory.
    with tempfile.TemporaryDirectory() as workdir:
        fits_path = os.path.join(workdir, "test.fits")
        hdf5_path = os.path.join(workdir, "test.hdf5")

        s.save_fits(fits_path, overwrite=True)
        s.save_hdf5(hdf5_path, overwrite=True)

        # Both round-trips must compare equal to the original object.
        assert sacc.Sacc.load_fits(fits_path) == s
        assert sacc.Sacc.load_hdf5(hdf5_path) == s