diff --git a/src/PyHyperScattering/FileIO.py b/src/PyHyperScattering/FileIO.py
index 816c716..fbb36ca 100644
--- a/src/PyHyperScattering/FileIO.py
+++ b/src/PyHyperScattering/FileIO.py
@@ -34,12 +34,130 @@ def __init__(self,xr_obj):
     def savePickle(self,filename):
         with open(filename, 'wb') as file:
             pickle.dump(self._obj, file)
-
+
+    def sanitize_attrs(self, xr_obj):
+        """
+        Sanitize the attributes of an xarray object to make them JSON serializable,
+        handling deeply nested dictionaries, lists, and array-like objects.
+
+        Parameters:
+            xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to sanitize.
+
+        Returns:
+            xarray.DataArray or xarray.Dataset: A copy of the input object with sanitized attributes.
+        """
+        def sanitize_value(value):
+            """Recursively sanitize a value to ensure JSON serializability."""
+            if isinstance(value, datetime.datetime):
+                return value.isoformat()  # Convert datetime to ISO 8601 string
+            elif isinstance(value, np.ndarray):
+                return value.tolist()  # Convert numpy arrays to lists
+            elif hasattr(value, "__array__"):  # Handles other array-like objects
+                return np.asarray(value).tolist()
+            elif isinstance(value, dict):
+                # Recursively sanitize dictionary values
+                return {k: sanitize_value(v) for k, v in value.items()}
+            elif isinstance(value, list):
+                # Recursively sanitize list elements
+                return [sanitize_value(v) for v in value]
+            else:
+                try:
+                    # Check if the value can be serialized to JSON
+                    json.dumps(value)
+                    return value
+                except (TypeError, OverflowError):
+                    return None  # Mark non-serializable values as None
+
+        sanitized_obj = xr_obj.copy()
+        sanitized_attrs = {}
+        dropped_attrs = {}
+
+        for key, value in sanitized_obj.attrs.items():
+            sanitized_value = sanitize_value(value)
+            if sanitized_value is not None:
+                sanitized_attrs[key] = sanitized_value
+            else:
+                dropped_attrs[key] = value
+
+        sanitized_obj.attrs = sanitized_attrs
+
+        # Print or log a summary of the sanitized attributes
+        if dropped_attrs:
+            print("Dropped non-serializable attributes:")
+            for key, value in dropped_attrs.items():
+                print(f"  {key}: {type(value)} - {value}")
+        else:
+            print("No attributes were dropped.")
+
+        if sanitized_attrs:
+            print("\nConverted attributes:")
+            for key, value in sanitized_attrs.items():
+                print(f"  {key}: {type(value)} -> {value}")
+
+        return sanitized_obj
+
+    def make_attrs_netcdf_safe(self, xr_obj):
+        """
+        Make the attributes of an xarray object safe for NetCDF by JSON-encoding
+        dictionaries and other complex data types.
+
+        Parameters:
+            xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to process.
+
+        Returns:
+            xarray.DataArray or xarray.Dataset: A copy of the input object with NetCDF-safe attributes.
+        """
+        def encode_complex(value):
+            """
+            Encode complex data types (like dicts) into JSON strings.
+            """
+            if isinstance(value, (dict, list, tuple)):
+                try:
+                    # Convert to a JSON string
+                    return json.dumps(value)
+                except (TypeError, OverflowError) as e:
+                    # Handle unexpected cases gracefully
+                    print(f"Error encoding attribute value: {value} ({e})")
+                    return None
+            return value
+
+        sanitized_obj = xr_obj.copy()
+        encoded_attrs = {}
+
+        for key, value in sanitized_obj.attrs.items():
+            encoded_value = encode_complex(value)
+            if encoded_value is not None:
+                encoded_attrs[key] = encoded_value
+            else:
+                print(f"Dropping unsupported attribute: {key} -> {value}")
+
+        sanitized_obj.attrs = encoded_attrs
+
+        return sanitized_obj
+
     # - This was copied from the Toney group contribution for GIWAXS.
     def saveZarr(self, filename, mode: str = 'w'):
-        """
-        Save the DataArray as a .zarr file in a specific path, with a file name constructed from a prefix and suffix.
+        """
+        Save the DataArray as a .zarr store at the given path.
+
+        Parameters:
+            filename (Union[str, pathlib.Path]): The path of the .zarr store to write.
+            mode (str): The mode to use when saving the file. Default is 'w'.
+        """
+        da = self._obj
+        # unstack any pandas MultiIndex dims before writing (assumes pandas is imported as pd)
+        if hasattr(da, "indexes"):
+            multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], pd.MultiIndex)]
+            da = da.unstack(multiindexes) if multiindexes else da
+        ds = da.to_dataset(name='DA')
+        ds = self.sanitize_attrs(ds)
+
+        file_path = pathlib.Path(filename)
+        ds.to_zarr(file_path, mode=mode)
+
+    def saveNetCDF(self, filename):
+        """
+        Save the DataArray as a netcdf file in a specific path, with a file name constructed from a prefix and suffix.
 
         Parameters:
             da (xr.DataArray): The DataArray to be saved.
@@ -47,12 +165,18 @@ def saveZarr(self, filename, mode: str = 'w'):
             prefix (str): The prefix to use for the file name.
             suffix (str): The suffix to use for the file name.
             mode (str): The mode to use when saving the file. Default is 'w'.
-        """
-        da = self._obj
-        ds = da.to_dataset(name='DA')
-        file_path = pathlib.Path(filename)
-        ds.to_zarr(file_path, mode=mode)
-
+        """
+        da = self._obj
+        # sanitize attrs and make netcdf safe by converting dicts to json strings
+        da = self.sanitize_attrs(da)
+        da = self.make_attrs_netcdf_safe(da)
+        # unstack any pandas MultiIndex dims before writing (assumes pandas is imported as pd)
+        if hasattr(da, "indexes"):
+            multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], pd.MultiIndex)]
+            da = da.unstack(multiindexes) if multiindexes else da
+        file_path = pathlib.Path(filename)
+        da.to_netcdf(file_path)
+
     def saveNexus(self,fileName,compression=5):
         data = self._obj
         timestamp = datetime.datetime.now()
@@ -309,4 +433,4 @@ def _make_coords(f):
         else:
             coords[axes[n]] = f['entry']['sasdata'][axis]
 
-    return coords
\ No newline at end of file
+    return coords
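
For reference, here is a minimal usage sketch of the two new save paths. It assumes this class is registered on xarray objects as the `fileio` accessor (as with PyHyperScattering's other accessors) and that the zarr and netCDF backends are installed; the file names and attribute values are purely illustrative.

```python
import numpy as np
import xarray as xr

import PyHyperScattering.FileIO  # assumed to register the accessor used below

# A small DataArray with attrs that are not natively serializable:
# a nested dict and a numpy array.
da = xr.DataArray(
    np.random.rand(4, 4),
    dims=("pix_x", "pix_y"),
    attrs={
        "exposure": 2.5,                         # already JSON serializable
        "motors": {"th": 0.15, "det_z": 120.0},  # dict -> JSON string for NetCDF
        "dark_ids": np.array([101, 102, 103]),   # ndarray -> list
    },
)

# Zarr path: attrs are sanitized and any MultiIndex dims unstacked before writing.
da.fileio.saveZarr("sample_image.zarr", mode="w")

# NetCDF path: attrs are additionally JSON-encoded so nested types survive to_netcdf.
da.fileio.saveNetCDF("sample_image.nc")
```

Both writers route through `sanitize_attrs`, so attributes that cannot be made serializable are dropped and reported rather than raising inside `to_zarr`/`to_netcdf`; `saveNetCDF` additionally JSON-encodes dict, list, and tuple attributes via `make_attrs_netcdf_safe` before calling `to_netcdf`.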