Skip to content

Commit

Permalink
Add optional dependency mechanism and roll out across codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
pbeaucage committed Feb 15, 2025
1 parent 0405081 commit f2f4540
Show file tree
Hide file tree
Showing 15 changed files with 483 additions and 235 deletions.
33 changes: 21 additions & 12 deletions src/PyHyperScattering/ALS11012RSoXSLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@
import warnings
import re
import PyHyperScattering
from .optional_dependencies import requires_optional, check_optional_dependency, warn_if_missing

try:
# Check for optional dependencies
HAS_ASTROPY = check_optional_dependency('astropy')
if HAS_ASTROPY:
from astropy.io import fits
except ImportError:
warnings.warn('Could not import astropy.io.fits, needed for ALS 11.0.1.2 RSoXS loading. Is this dependency installed?',stacklevel=2)
else:
warn_if_missing('astropy')


class ALS11012RSoXSLoader(FileLoader):
'''
Loader for FITS files from the ALS 11.0.1.2 RSoXS instrument
Additional requirement: astropy, for FITS file loader
Note: This loader requires the 'astropy' package for reading FITS files.
If not installed, the loader will not be functional.
Usage is mainly via the inherited function integrateImageStack from FileLoader
'''
file_ext = '(.*?).fits'
md_loading_is_quick = True


def __init__(self,corr_mode=None,user_corr_func=None,dark_pedestal=0,exposure_offset=0.002,dark_subtract=False,data_collected_after_mar2021=False,constant_md={}):
'''
Args:
Expand All @@ -38,6 +38,9 @@ def __init__(self,corr_mode=None,user_corr_func=None,dark_pedestal=0,exposure_of
data_collected_after_mar2021 (boolean, default False): if True, uses 'CCD Camera Shutter Inhibit' as the dark-indicator; if False, uses 'CCD Shutter Inhibit'
constant_md (dict): values to insert into every metadata load. Example: beamcenter_x, beamcenter_y, sdd to enable qx/qy loading.
'''
if not HAS_ASTROPY:
raise ImportError("The 'astropy' package is required for this loader to function. Please install it first.")

if corr_mode == None:
warnings.warn("Correction mode was not set, not performing *any* intensity corrections. Are you sure this is "+
"right? Set corr_mode to 'none' to suppress this warning.",stacklevel=2)
Expand All @@ -51,18 +54,21 @@ def __init__(self,corr_mode=None,user_corr_func=None,dark_pedestal=0,exposure_of
data_collected_after_mar2021 = False
else:
data_collected_after_mar2021 = True

if data_collected_after_mar2021:
self.shutter_inhibit = 'CCD Camera Shutter Inhibit'
else:
self.shutter_inhibit = 'CCD Shutter Inhibit'
self.dark_pedestal = dark_pedestal
self.user_corr_func = user_corr_func
self.dark_pedestal = dark_pedestal
self.exposure_offset = exposure_offset
self.darks = {}
self.constant_md = constant_md
self.dark_subtract = dark_subtract
self.data_collected_after_mar2021 = data_collected_after_mar2021
self.constant_md = constant_md
self.darks = {}

@requires_optional('astropy')
def loadDarks(self,basepath,dark_base_name):
'''
Load a series of dark images as a function of exposure time, to be subtracted from subsequently-loaded data.
Expand All @@ -81,6 +87,7 @@ def loadDarks(self,basepath,dark_base_name):
self.darks[exptime] = darkimage[2].data


@requires_optional('astropy')
def loadSampleSpecificDarks(self,basepath,file_filter='',file_skip='donotskip',md_filter={}):
'''
load darks matching a specific sample metadata
Expand Down Expand Up @@ -118,6 +125,8 @@ def loadSampleSpecificDarks(self,basepath,file_filter='',file_skip='donotskip',m
print(f'Loading dark for {md["EXPOSURE"]} from {file}')
exptime = md['EXPOSURE']
self.darks[exptime] = img

@requires_optional('astropy')
def loadSingleImage(self,filepath,coords=None,return_q=False,**kwargs):
'''
THIS IS A HELPER FUNCTION, mostly - should not be called directly unless you know what you are doing
Expand Down Expand Up @@ -177,6 +186,7 @@ def loadSingleImage(self,filepath,coords=None,return_q=False,**kwargs):
return xr.DataArray(img,dims=['qy','qx'],coords={'qy':qy,'qx':qx},attrs=headerdict)
return xr.DataArray(img,dims=['pix_x','pix_y'],attrs=headerdict)

@requires_optional('astropy')
def peekAtMd(self,file):
'''
load the header/metadata without opening the corresponding image
Expand Down Expand Up @@ -211,4 +221,3 @@ def normalizeMetadata(self,headerdict):
headerdict['det_th'] = round(headerdict['CCD Theta'],2)
headerdict.update(self.constant_md)
return headerdict

28 changes: 24 additions & 4 deletions src/PyHyperScattering/CMSGIWAXSLoader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import pathlib
import warnings
import fabio
from PIL import Image
from PyHyperScattering.FileLoader import FileLoader
import xarray as xr
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from .optional_dependencies import requires_optional, check_optional_dependency, warn_if_missing

# Check for optional dependencies
HAS_FABIO = check_optional_dependency('fabio')
HAS_PIL = check_optional_dependency('PIL')

if HAS_FABIO:
import fabio
else:
warn_if_missing('fabio')

if HAS_PIL:
from PIL import Image
else:
warn_if_missing('PIL')


class CMSGIWAXSLoader(FileLoader):
"""
GIXS Data Loader Class | NSLS-II 11-BM (CMS)
Used to load single TIFF or time-series TIFF GIWAXS images.
Note: This loader requires either the 'fabio' or 'PIL' package for reading image files.
At least one of these packages must be installed for the loader to function.
"""
def __init__(self, md_naming_scheme=[], root_folder=None, delim='_'):
"""
Expand All @@ -24,6 +41,8 @@ def __init__(self, md_naming_scheme=[], root_folder=None, delim='_'):
delim: delimiter value to split filename (default is underscore)
"""
if not (HAS_FABIO or HAS_PIL):
raise ImportError("Either 'fabio' or 'PIL' package is required for this loader to function. Please install at least one of them.")

self.md_naming_scheme = md_naming_scheme
if len(md_naming_scheme) == 0:
Expand All @@ -33,7 +52,7 @@ def __init__(self, md_naming_scheme=[], root_folder=None, delim='_'):
self.sample_dict = None
self.selected_series = []

def loadSingleImage(self, filepath,coords=None,return_q=False,image_slice=None,use_cached_md=False,**kwargs):
def loadSingleImage(self, filepath, coords=None, return_q=False, image_slice=None, use_cached_md=False, **kwargs):
"""
Loads a single xarray DataArray from a filepath to a raw TIFF
Expand All @@ -46,8 +65,9 @@ def loadSingleImage(self, filepath,coords=None,return_q=False,image_slice=None,u
- image_slice
- use_cached_md
Note:
This method will attempt to use fabio first, then fall back to PIL if fabio is not available.
"""

# Ensure that the path exists before continuing.
filepath = pathlib.Path(filepath)

Expand Down
32 changes: 25 additions & 7 deletions src/PyHyperScattering/ESRFID2Loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from PIL import Image
from PyHyperScattering.FileLoader import FileLoader
import os
import pathlib
Expand All @@ -7,17 +6,33 @@
import datetime
import warnings
import json
#from pyFAI import azimuthalIntegrator
import numpy as np
import h5py
import copy

import re
from .optional_dependencies import requires_optional, check_optional_dependency, warn_if_missing

# Check for optional dependencies
HAS_PIL = check_optional_dependency('PIL')
HAS_H5PY = check_optional_dependency('h5py')

if HAS_PIL:
from PIL import Image
else:
warn_if_missing('PIL')

if HAS_H5PY:
import h5py
else:
warn_if_missing('h5py')


class ESRFID2Loader(FileLoader):
'''
Loader for NEXUS files from the ID2 beamline at the ESRF
Note: This loader requires the following optional packages:
- 'h5py': Required for reading HDF5/NEXUS files
- 'PIL': Required for image processing
'''
file_ext = '(.*)eiger2(.*).h5'
md_loading_is_quick = True
Expand All @@ -28,8 +43,10 @@ def __init__(self,md_parse_dict=None,pedestal_value=1e-6,masked_pixel_fill=np.na
md_parse_dict (dict): keys should be names of underscore separated parameters in title. values should be regex to parse values
pedestal_value: value to add to image in order to deal with zero_counts
masked_pixel_fill: If None, pixels with value -10 will be converted to NaN. Otherwise, will be converted to this value
'''
if not HAS_H5PY:
raise ImportError("The 'h5py' package is required for this loader to function. Please install it first.")

if md_parse_dict is None:
self.md_regex = None
self.md_keys=None
Expand All @@ -44,11 +61,12 @@ def __init__(self,md_parse_dict=None,pedestal_value=1e-6,masked_pixel_fill=np.na
self.pedestal_value=pedestal_value
self.masked_pixel_fill = masked_pixel_fill
self.cached_md = None


@requires_optional('h5py')
def loadMd(self, filepath, split_on='_', keys=None):
    '''
    Load metadata for a NEXUS/HDF5 file; thin wrapper around peekAtMd.

    Args:
        filepath (str or pathlib.Path): path to the file to inspect
        split_on (str): delimiter used to split the title into metadata tokens
        keys: optional key list forwarded to peekAtMd

    Returns:
        parsed metadata parameters (see peekAtMd)
    '''
    # Forward the caller's arguments instead of hardcoding split_on='_' and
    # silently dropping keys (the previous body ignored both parameters).
    return self.peekAtMd(filepath, split_on=split_on, keys=keys)

@requires_optional('h5py')
def peekAtMd(self,filepath,split_on='_',keys=None):
## Open h5 file and grab attributes
with h5py.File(str(filepath),'r') as h5:
Expand Down Expand Up @@ -105,6 +123,7 @@ def peekAtMd(self,filepath,split_on='_',keys=None):
return params


@requires_optional('h5py')
def loadSingleImage(self,filepath,coords=None,return_q=True,image_slice=None,use_cached_md=False,**kwargs):
'''
HELPER FUNCTION that loads a single image and returns an xarray with either pix_x / pix_y dimensions (if return_q == False) or qx / qy (if return_q == True)
Expand Down Expand Up @@ -171,4 +190,3 @@ def loadSingleImage(self,filepath,coords=None,return_q=True,image_slice=None,use
img += self.pedestal_value

return img

48 changes: 28 additions & 20 deletions src/PyHyperScattering/FileIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,36 @@
import numpy as np
import pickle
import math
import h5py
import pathlib
import datetime
import six
import PyHyperScattering
import pandas
import json

from collections import defaultdict
from . import _version
phs_version = _version.get_versions()['version']
from .optional_dependencies import requires_optional, check_optional_dependency, warn_if_missing

# Check for optional dependencies
HAS_H5PY = check_optional_dependency('h5py')

if HAS_H5PY:
import h5py
else:
warn_if_missing('h5py')

phs_version = _version.get_versions()['version']


@xr.register_dataset_accessor('fileio')
@xr.register_dataarray_accessor('fileio')
class FileIO:
"""
File I/O accessor for xarray DataArrays and Datasets.
Note: Some methods in this class require optional dependencies:
- 'h5py': Required for HDF5/NEXUS file operations
"""
def __init__(self,xr_obj):
self._obj=xr_obj

Expand All @@ -32,27 +45,22 @@ def __init__(self,xr_obj):
self._pyhyper_type = 'raw'

def savePickle(self, filename):
    """Persist the wrapped DataArray/Dataset to disk via pickle."""
    payload = self._obj
    # Pickle requires a binary-mode file handle.
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)

def saveZarr(self, filename, mode: str = 'w'):
    """
    Save the wrapped DataArray as a .zarr store.

    Parameters:
        filename (Union[str, pathlib.Path]): path to save the .zarr store to
        mode (str): write mode passed to to_zarr. Default is 'w' (overwrite).
    """
    # NOTE(review): this span contained two interleaved diff versions of the
    # method; this is the consolidated post-commit implementation, which
    # writes the DataArray directly rather than wrapping it in a Dataset.
    da = self._obj
    da.to_zarr(filename, mode=mode)

@requires_optional('h5py')
def saveNexus(self,fileName,compression=5):
data = self._obj
timestamp = datetime.datetime.now()
Expand Down Expand Up @@ -266,7 +274,7 @@ def _unserialize_attrs(hdf,attrdict):
encoding.replace('strftime-',''))
else:
warnings.warn(f'Unknown phs_encoding {encoding} while loading {entry}. Possible version mismatch. Loading as string.',stacklevel=2)
attrdict[entry] = hdf[entry][()]
attrdict[entry] = hdf[entry][()]
except KeyError:
attrdict[entry] = hdf[entry][()]
return attrdict
Expand Down
Loading

0 comments on commit f2f4540

Please sign in to comment.