Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sanitization functions and netcdf export (with cleanup) to FileIO #173

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 134 additions & 10 deletions src/PyHyperScattering/FileIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,149 @@ def __init__(self,xr_obj):
def savePickle(self,filename):
with open(filename, 'wb') as file:
pickle.dump(self._obj, file)


def sanitize_attrs(xr_obj):
"""
Sanitize the attributes of an xarray object to make them JSON serializable,
handling deeply nested dictionaries, lists, and array-like objects.

Parameters:
xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to sanitize.

Returns:
xarray.DataArray or xarray.Dataset: A copy of the input object with sanitized attributes.
"""
def sanitize_value(value):
"""Recursively sanitize a value to ensure JSON serializability."""
if isinstance(value, datetime):
return value.isoformat() # Convert datetime to ISO 8601 string
elif isinstance(value, np.ndarray):
return value.tolist() # Convert numpy arrays to lists
elif hasattr(value, "__array__"): # Handles other array-like objects
return np.asarray(value).tolist()
elif isinstance(value, dict):
# Recursively sanitize dictionary values
return {k: sanitize_value(v) for k, v in value.items()}
elif isinstance(value, list):
# Recursively sanitize list elements
return [sanitize_value(v) for v in value]
else:
try:
# Check if the value can be serialized to JSON
json.dumps(value)
return value
except (TypeError, OverflowError):
return None # Mark non-serializable values as None

sanitized_obj = xr_obj.copy()
sanitized_attrs = {}
dropped_attrs = {}

for key, value in sanitized_obj.attrs.items():
sanitized_value = sanitize_value(value)
if sanitized_value is not None:
sanitized_attrs[key] = sanitized_value
else:
dropped_attrs[key] = value

sanitized_obj.attrs = sanitized_attrs

# Print or log a summary of the sanitized attributes
if dropped_attrs:
print("Dropped non-serializable attributes:")
for key, value in dropped_attrs.items():
print(f" {key}: {type(value)} - {value}")
else:
print("No attributes were dropped.")

if sanitized_attrs:
print("\nConverted attributes:")
for key, value in sanitized_attrs.items():
print(f" {key}: {type(value)} -> {value}")

return sanitized_obj
def make_attrs_netcdf_safe(xr_obj):
"""
Make the attributes of an xarray object safe for NetCDF by JSON-encoding
dictionaries and other complex data types.

Parameters:
xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to process.

Returns:
xarray.DataArray or xarray.Dataset: A copy of the input object with NetCDF-safe attributes.
"""
def encode_complex(value):
"""
Encode complex data types (like dicts) into JSON strings.
"""
if isinstance(value, (dict, list, tuple)):
try:
# Convert to a JSON string
return json.dumps(value)
except (TypeError, OverflowError) as e:
# Handle unexpected cases gracefully
print(f"Error encoding attribute value: {value} ({e})")
return None
return value

sanitized_obj = xr_obj.copy()
encoded_attrs = {}

for key, value in sanitized_obj.attrs.items():
encoded_value = encode_complex(value)
if encoded_value is not None:
encoded_attrs[key] = encoded_value
else:
print(f"Dropping unsupported attribute: {key} -> {value}")

sanitized_obj.attrs = encoded_attrs

return sanitized_obj

# - This was copied from the Toney group contribution for GIWAXS.
def saveZarr(self, filename, mode: str = 'w'):
"""
Save the DataArray as a .zarr file in a specific path, with a file name constructed from a prefix and suffix.
"""
Save the DataArray as a .zarr file in a specific path, with a file name constructed from a prefix and suffix.
Parameters:
da (xr.DataArray): The DataArray to be saved.
base_path (Union[str, pathlib.Path]): The base path to save the .zarr file.
prefix (str): The prefix to use for the file name.
suffix (str): The suffix to use for the file name.
mode (str): The mode to use when saving the file. Default is 'w'.
"""
da = self._obj
ds = da.to_dataset(name='DA')
ds = self.sanitize_attrs(ds)
# unstack any multiindexes on the array
if hasattr(da, "indexes"):
multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xr.core.indexes.MultiIndex)]
da = da.unstack(multiindexes) if multiindexes else da

file_path = pathlib.Path(filename)
ds.to_zarr(file_path, mode=mode)
def saveNetCDF(self, filename):
"""
Save the DataArray as a netcdf file in a specific path, with a file name constructed from a prefix and suffix.

Parameters:
da (xr.DataArray): The DataArray to be saved.
base_path (Union[str, pathlib.Path]): The base path to save the .zarr file.
prefix (str): The prefix to use for the file name.
suffix (str): The suffix to use for the file name.
mode (str): The mode to use when saving the file. Default is 'w'.
"""
da = self._obj
ds = da.to_dataset(name='DA')
file_path = pathlib.Path(filename)
ds.to_zarr(file_path, mode=mode)

"""
da = self._obj
# sanitize attrs and make netcdf safe by converting dicts to json strings
da = self.sanitize_attrs(da)
da = self.make_attrs_netcdf_safe(da)
# unstack any multiindexes on the array
if hasattr(da, "indexes"):
multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xr.core.indexes.MultiIndex)]
da = da.unstack(multiindexes) if multiindexes else da
file_path = pathlib.Path(filename)
da.to_netcdf(file_path)

def saveNexus(self,fileName,compression=5):
data = self._obj
timestamp = datetime.datetime.now()
Expand Down Expand Up @@ -309,4 +433,4 @@ def _make_coords(f):
else:
coords[axes[n]] = f['entry']['sasdata'][axis]

return coords
return coords
Loading