From 248cdbe8988e186e9901c1eaa033c5ba3d8869df Mon Sep 17 00:00:00 2001 From: Peter Beaucage Date: Thu, 23 Jan 2025 12:34:07 -0500 Subject: [PATCH 1/2] Add sanitization functions and netcdf export (with cleanup) --- src/PyHyperScattering/FileIO.py | 126 +++++++++++++++++++++++++++++++- 1 file changed, 123 insertions(+), 3 deletions(-) diff --git a/src/PyHyperScattering/FileIO.py b/src/PyHyperScattering/FileIO.py index 816c7165..9fd705ba 100644 --- a/src/PyHyperScattering/FileIO.py +++ b/src/PyHyperScattering/FileIO.py @@ -34,7 +34,105 @@ def __init__(self,xr_obj): def savePickle(self,filename): with open(filename, 'wb') as file: pickle.dump(self._obj, file) - + + def sanitize_attrs(xr_obj): + """ + Sanitize the attributes of an xarray object to make them JSON serializable, + handling deeply nested dictionaries, lists, and array-like objects. + + Parameters: + xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to sanitize. + + Returns: + xarray.DataArray or xarray.Dataset: A copy of the input object with sanitized attributes. + """ + def sanitize_value(value): + """Recursively sanitize a value to ensure JSON serializability.""" + if isinstance(value, datetime): + return value.isoformat() # Convert datetime to ISO 8601 string + elif isinstance(value, np.ndarray): + return value.tolist() # Convert numpy arrays to lists + elif hasattr(value, "__array__"): # Handles other array-like objects + return np.asarray(value).tolist() + elif isinstance(value, dict): + # Recursively sanitize dictionary values + return {k: sanitize_value(v) for k, v in value.items()} + elif isinstance(value, list): + # Recursively sanitize list elements + return [sanitize_value(v) for v in value] + else: + try: + # Check if the value can be serialized to JSON + json.dumps(value) + return value + except (TypeError, OverflowError): + return None # Mark non-serializable values as None + + sanitized_obj = xr_obj.copy() + sanitized_attrs = {} + dropped_attrs = {} + + for key, value in sanitized_obj.attrs.items(): + sanitized_value = sanitize_value(value) + if sanitized_value is not None: + sanitized_attrs[key] = sanitized_value + else: + dropped_attrs[key] = value + + sanitized_obj.attrs = sanitized_attrs + + # Print or log a summary of the sanitized attributes + if dropped_attrs: + print("Dropped non-serializable attributes:") + for key, value in dropped_attrs.items(): + print(f" {key}: {type(value)} - {value}") + else: + print("No attributes were dropped.") + + if sanitized_attrs: + print("\nConverted attributes:") + for key, value in sanitized_attrs.items(): + print(f" {key}: {type(value)} -> {value}") + + return sanitized_obj + def make_attrs_netcdf_safe(xr_obj): + """ + Make the attributes of an xarray object safe for NetCDF by JSON-encoding + dictionaries and other complex data types. + + Parameters: + xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to process. + + Returns: + xarray.DataArray or xarray.Dataset: A copy of the input object with NetCDF-safe attributes. + """ + def encode_complex(value): + """ + Encode complex data types (like dicts) into JSON strings. + """ + if isinstance(value, (dict, list, tuple)): + try: + # Convert to a JSON string + return json.dumps(value) + except (TypeError, OverflowError) as e: + # Handle unexpected cases gracefully + print(f"Error encoding attribute value: {value} ({e})") + return None + return value + + sanitized_obj = xr_obj.copy() + encoded_attrs = {} + + for key, value in sanitized_obj.attrs.items(): + encoded_value = encode_complex(value) + if encoded_value is not None: + encoded_attrs[key] = encoded_value + else: + print(f"Dropping unsupported attribute: {key} -> {value}") + + sanitized_obj.attrs = encoded_attrs + + return sanitized_obj # - This was copied from the Toney group contribution for GIWAXS. def saveZarr(self, filename, mode: str = 'w'): @@ -50,9 +148,31 @@ def saveZarr(self, filename, mode: str = 'w'): """ da = self._obj ds = da.to_dataset(name='DA') + ds = self.sanitize_attrs(ds) file_path = pathlib.Path(filename) ds.to_zarr(file_path, mode=mode) - + def saveNetCDF(self, filename): + """ + Save the DataArray as a netcdf file in a specific path, with a file name constructed from a prefix and suffix. + + Parameters: + da (xr.DataArray): The DataArray to be saved. + base_path (Union[str, pathlib.Path]): The base path to save the .zarr file. + prefix (str): The prefix to use for the file name. + suffix (str): The suffix to use for the file name. + mode (str): The mode to use when saving the file. Default is 'w'. + """ + da = self._obj + # sanitize attrs and make netcdf safe by converting dicts to json strings + da = self.sanitize_attrs(da) + da = self.make_attrs_netcdf_safe(da) + # unstack any multiindexes on the array + if hasattr(da, "indexes"): + multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xarray.core.indexes.MultiIndex)] + da = da.unstack(multiindexes) if multiindexes else da + file_path = pathlib.Path(filename) + da.to_netcdf(file_path) + def saveNexus(self,fileName,compression=5): data = self._obj timestamp = datetime.datetime.now() @@ -309,4 +429,4 @@ def _make_coords(f): else: coords[axes[n]] = f['entry']['sasdata'][axis] - return coords \ No newline at end of file + return coords From 11fd56004e20aa2f252f9a992faeba5179d355c0 Mon Sep 17 00:00:00 2001 From: Peter Beaucage Date: Thu, 23 Jan 2025 12:48:47 -0500 Subject: [PATCH 2/2] Typo fix and add unstacking to saveZarr() --- src/PyHyperScattering/FileIO.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/PyHyperScattering/FileIO.py b/src/PyHyperScattering/FileIO.py index 9fd705ba..fbb36ca0 100644 --- a/src/PyHyperScattering/FileIO.py +++ b/src/PyHyperScattering/FileIO.py @@ -136,21 +136,25 @@ def encode_complex(value): # - This was copied from the Toney group contribution for GIWAXS. def saveZarr(self, filename, mode: str = 'w'): - """ - Save the DataArray as a .zarr file in a specific path, with a file name constructed from a prefix and suffix. - + """ + Save the DataArray as a .zarr file in a specific path, with a file name constructed from a prefix and suffix. Parameters: - da (xr.DataArray): The DataArray to be saved. - base_path (Union[str, pathlib.Path]): The base path to save the .zarr file. - prefix (str): The prefix to use for the file name. - suffix (str): The suffix to use for the file name. - mode (str): The mode to use when saving the file. Default is 'w'. - """ - da = self._obj - ds = da.to_dataset(name='DA') - ds = self.sanitize_attrs(ds) - file_path = pathlib.Path(filename) - ds.to_zarr(file_path, mode=mode) + da (xr.DataArray): The DataArray to be saved. + base_path (Union[str, pathlib.Path]): The base path to save the .zarr file. + prefix (str): The prefix to use for the file name. + suffix (str): The suffix to use for the file name. + mode (str): The mode to use when saving the file. Default is 'w'. + """ + da = self._obj + ds = da.to_dataset(name='DA') + ds = self.sanitize_attrs(ds) + # unstack any multiindexes on the array + if hasattr(da, "indexes"): + multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xr.core.indexes.MultiIndex)] + da = da.unstack(multiindexes) if multiindexes else da + + file_path = pathlib.Path(filename) + ds.to_zarr(file_path, mode=mode) def saveNetCDF(self, filename): """ Save the DataArray as a netcdf file in a specific path, with a file name constructed from a prefix and suffix. @@ -168,7 +172,7 @@ def saveNetCDF(self, filename): da = self.make_attrs_netcdf_safe(da) # unstack any multiindexes on the array if hasattr(da, "indexes"): - multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xarray.core.indexes.MultiIndex)] + multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xr.core.indexes.MultiIndex)] da = da.unstack(multiindexes) if multiindexes else da file_path = pathlib.Path(filename) da.to_netcdf(file_path)