Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 173 additions & 15 deletions spectral_util/spec_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, band_names, geotransform=None, projection=None, glt=None, pre


class SpectralMetadata:
def __init__(self, wavelengths, fwhm, geotransform=None, projection=None, glt=None, pre_orthod=False, nodata_value=None):
def __init__(self, wavelengths, fwhm, geotransform=None, projection=None, glt=None, pre_orthod=False, nodata_value=None, band_names=None):
"""
Initializes the SpectralMetadata object.

Expand All @@ -66,6 +66,7 @@ def __init__(self, wavelengths, fwhm, geotransform=None, projection=None, glt=No
self.glt = glt
self.pre_orthod = False
self.nodata_value = nodata_value
self.band_names = band_names

if pre_orthod:
self.orthoable = False
Expand All @@ -92,14 +93,15 @@ def wl_index(self, wl, buffer=None):
return np.where(np.logical_and(self.wl >= wl - buffer, self.wl <= wl + buffer))


def load_data(input_file, lazy=True, load_glt=False, load_loc=False):
def load_data(input_file, lazy=True, load_glt=False, load_loc=False, mask_type=None, return_loc_from_l1b_rad_nc=False):
"""
Loads a file and extracts the spectral metadata and data.

Args:
input_file (str): Path to the input file.
lazy (bool, optional): If True, loads the data lazily. Defaults to True.
load_glt (bool, optional): If True, loads the glt for orthoing. Defaults to False.
return_loc_from_l1b_rad

Raises:
ValueError: If the file type is unknown.
Expand All @@ -116,7 +118,8 @@ def load_data(input_file, lazy=True, load_glt=False, load_loc=False):
if input_filename.endswith(('.hdr', '.dat', '.img')) or '.' not in input_filename:
return open_envi(input_file, lazy=lazy)
elif input_filename.endswith('.nc'):
return open_netcdf(input_file, lazy=lazy, load_glt=load_glt, load_loc=load_loc)
return open_netcdf(input_file, lazy=lazy, load_glt=load_glt, load_loc=load_loc,
mask_type=mask_type, return_loc_from_l1b_rad_nc=return_loc_from_l1b_rad_nc)
elif input_filename.endswith('.tif') or input_filename.endswith('.vrt'):
return open_tif(input_file, lazy=lazy)
else:
Expand Down Expand Up @@ -238,20 +241,26 @@ def open_envi(input_file, lazy=True):
else:
nodata_value = -9999 # set default

if 'band names' in imeta:
band_names = imeta['band names']
else:
band_names = 'None'

if 'coordinate system string' in imeta:
css = imeta['coordinate system string']
proj = css if type(css) == str else ','.join(css)
else:
proj = None

map_info, trans = None, None
if 'map info' in imeta:
map_info = imeta['map info'].split(',') if type(imeta['map info']) == str else imeta['map info']
rotation=0
for val in map_info:
if 'rotation=' in val:
rotation = float(val.replace('rotation=','').strip())
trans = [float(map_info[3]), float(map_info[5]), rotation, float(map_info[4]), rotation, -float(map_info[6])]
else:
map_info, trans = None, None
if imeta['map info'][0] != '':
map_info = imeta['map info'].split(',') if type(imeta['map info']) == str else imeta['map info']
rotation=0
for val in map_info:
if 'rotation=' in val:
rotation = float(val.replace('rotation=','').strip())
trans = [float(map_info[3]), float(map_info[5]), rotation, float(map_info[4]), rotation, -float(map_info[6])]

glt = None
if 'glt' in os.path.basename(input_file).lower():
Expand All @@ -262,7 +271,7 @@ def open_envi(input_file, lazy=True):
else:
rfl = ds.open_memmap(interleave='bip').copy()

meta = SpectralMetadata(wl, fwhm, nodata_value=nodata_value, geotransform=trans, projection=proj, glt=glt)
meta = SpectralMetadata(wl, fwhm, nodata_value=nodata_value, geotransform=trans, projection=proj, glt=glt, band_names=band_names)
return meta, rfl


Expand Down Expand Up @@ -295,7 +304,7 @@ def open_tif(input_file, lazy=False):
return meta, data


def open_netcdf(input_file, lazy=True, load_glt=False, load_loc=False):
def open_netcdf(input_file, lazy=True, load_glt=False, load_loc=False, mask_type=None, return_loc_from_l1b_rad_nc=None):
"""
Opens a NetCDF file and extracts the metadata and data.

Expand All @@ -311,11 +320,18 @@ def open_netcdf(input_file, lazy=True, load_glt=False, load_loc=False):
"""
input_filename = os.path.basename(input_file)
if 'EMIT' in input_filename and 'RAD' in input_filename:
return open_emit_rdn(input_file, lazy=lazy, load_glt=load_glt)
if return_loc_from_l1b_rad_nc:
return open_loc_l1b_rad_nc(input_file, lazy=lazy, load_glt=load_glt)
else:
return open_emit_rdn(input_file, lazy=lazy, load_glt=load_glt)
elif ('emit' in input_filename.lower() and 'obs' in input_filename.lower()):
return open_emit_obs_nc(input_file, lazy=lazy, load_glt=load_glt, load_loc=load_loc)
elif ('emit' in input_filename.lower() and 'l2a_mask' in input_filename.lower()):
return open_emit_l2a_mask_nc(input_file, mask_type, lazy=lazy, load_glt=load_glt, load_loc=load_loc)
elif 'AV3' in input_filename and 'RFL' in input_filename:
return open_airborne_rfl(input_file, lazy=lazy)
elif 'AV3' in input_filename and 'BANDMASK' in input_filename:
return open_av3_bandmask_nc(input_file, lazy=lazy)
elif 'AV3' in input_filename and 'RDN' in input_filename:
return open_airborne_rdn(input_file, lazy=lazy)
elif ('av3' in input_filename.lower() or 'ang' in input_filename.lower()) and 'OBS' in input_filename:
Expand Down Expand Up @@ -360,6 +376,115 @@ def open_emit_rdn(input_file, lazy=True, load_glt=False):

return meta, rdn

def open_loc_l1b_rad_nc(input_file, lazy=True, load_glt=False, load_loc=False):
"""
Opens an EMIT L2A_MASK NetCDF file and extracts the spectral metadata and mask data.

Args:
input_file (str): Path to the NetCDF file.
lazy (bool, optional): Ignored

Returns:
tuple: A tuple containing:
- GenericGeoMetadata: An object containing the band names
- numpy.ndarray or netCDF4.Variable: The mask data
"""
ds = nc.Dataset(input_file)
proj = ds.spatial_ref
trans = ds.geotransform

nodata_value = float(ds['location']['lon']._FillValue)
glt = None
if load_glt:
glt = np.stack([ds['location']['glt_x'][:],ds['location']['glt_y'][:]],axis=-1)
loc = None
if load_loc:
loc = np.stack([ds['location']['lon'][:],ds['location']['lat'][:]],axis=-1)

# Don't have a good solution for lazy here, temporarily ignoring...
if lazy:
logging.warning("Lazy loading not supported for L1B RAD LOC data.")

loc_plus_elev = np.stack([ds['location']['lat'], ds['location']['lon'], ds['location']['elev']], axis = -1)

meta = GenericGeoMetadata([ds['location']['lat'].long_name, ds['location']['lon'].long_name, ds['location']['elev'].long_name],
trans, proj, glt=glt, pre_orthod=True, nodata_value=nodata_value, loc=loc)

return meta, loc_plus_elev

def open_av3_bandmask_nc(input_file, lazy=True, load_glt=False, load_loc=False):
"""
Opens an EMIT L2A_MASK NetCDF file and extracts the spectral metadata and mask data.

Args:
input_file (str): Path to the NetCDF file.
lazy (bool, optional): Ignored

Returns:
tuple: A tuple containing:
- GenericGeoMetadata: An object containing the band names
- numpy.ndarray or netCDF4.Variable: The mask data
"""
ds = nc.Dataset(input_file)

nodata_value = float(ds['band_mask']._FillValue)

# Don't have a good solution for lazy here, temporarily ignoring...
if lazy:
logging.warning("Lazy loading not supported for BANDMASK data.")

mask = np.array(ds['band_mask'][...])

meta = GenericGeoMetadata(None, None, None, glt=None, pre_orthod=True, nodata_value=nodata_value, loc=None)

return meta, mask.transpose([1,2,0])

def open_emit_l2a_mask_nc(input_file, mask_type, lazy=True, load_glt=False, load_loc=False):
"""
Opens an EMIT L2A_MASK NetCDF file and extracts the spectral metadata and mask data.

Args:
input_file (str): Path to the NetCDF file.
mask_type (str): Mask type. Options are
'mask': L2A_MASK
'band_mask': L1B_BANDMASK
lazy (bool, optional): Ignored

Returns:
tuple: A tuple containing:
- GenericGeoMetadata: An object containing the band names
- numpy.ndarray or netCDF4.Variable: The mask data
"""
if not mask_type in ['mask', 'band_mask']:
raise ValueError(f"Invalid mask type {mask_type}. Must use either 'mask' or 'band_mask'")

ds = nc.Dataset(input_file)
proj = ds.spatial_ref
trans = ds.geotransform

if mask_type == 'mask':
mask_names = list(ds['sensor_band_parameters']['mask_bands'][...])
else:
mask_names = ['']

nodata_value = float(ds[mask_type]._FillValue)
glt = None
if load_glt:
glt = np.stack([ds['location']['glt_x'][:],ds['location']['glt_y'][:]],axis=-1)
loc = None
if load_loc:
loc = np.stack([ds['location']['lon'][:],ds['location']['lat'][:]],axis=-1)

# Don't have a good solution for lazy here, temporarily ignoring...
if lazy:
logging.warning("Lazy loading not supported for L2A mask data.")

mask = np.array(ds[mask_type][...])

meta = GenericGeoMetadata(mask_names, trans, proj, glt=glt, pre_orthod=True, nodata_value=nodata_value, loc=loc)

return meta, mask

def open_emit_obs_nc(input_file, lazy=True, load_glt=False, load_loc=False):
"""
Opens an EMIT observation NetCDF file and extracts the spectral metadata and obs data.
Expand Down Expand Up @@ -587,8 +712,41 @@ def create_envi_file(output_file, data_shape, meta, dtype=np.dtype(np.float32)):
if 'fwhm' in meta.__dict__ and meta.fwhm is not None:
header['fwhm'] = '{ ' + ', '.join(map(str, meta.fwhm)) + ' }'
if 'band_names' in meta.__dict__ and meta.band_names is not None:
header['band names'] = '{ ' + ', '.join(meta.band_names) + ' }'
if isinstance(meta.band_names, str):
header['band names'] = '{ ' + meta.band_names + ' }'
elif isinstance(meta.band_names, list):
header['band names'] = '{ ' + ', '.join(meta.band_names) + ' }'
else:
# Not sure what to do now, so just write it out as if it were a list
header['band names'] = '{ ' + ', '.join(meta.band_names) + ' }'

header['data ignore value'] = str(meta.nodata_value)

envi.write_envi_header(envi_header(output_file), header)

def write_geotiff(data, meta, output_filename):
"""
Creates a geotiff file with the given data and metadata.

Args:
data: data to write: nx, ny, nbands
meta (GenericGeoMetadata): The metadata
output_filename: Output file name (should include the .tif)
"""
write_data = data
if len(write_data.shape) == 2:
write_data = data.copy()[:,:,None]

driver = gdal.GetDriverByName('GTiff')
outDataset = driver.Create(output_filename,
write_data.shape[1], write_data.shape[0], write_data.shape[2],
gdal.GDT_Float32, options = ['COMPRESS=LZW'])

for i in range(write_data.shape[-1]):
outDataset.GetRasterBand(i+1).WriteArray(write_data[:,:,i])
outDataset.GetRasterBand(i+1).SetNoDataValue(-9999)

outDataset.SetProjection(meta.projection)
outDataset.SetGeoTransform(meta.geotransform)
outDataset.FlushCache() ##saves to disk!!
outDataset = None