From c0be02694a9e7141b06375e03449f40dfac226aa Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 6 Feb 2025 12:49:42 -0500 Subject: [PATCH 1/4] Apply code style rules to wfss_contam --- .pre-commit-config.yaml | 1 - .ruff.toml | 4 +- jwst/wfss_contam/__init__.py | 2 + jwst/wfss_contam/disperse.py | 66 ++--- jwst/wfss_contam/observations.py | 277 +++++++++++++------- jwst/wfss_contam/sens1d.py | 29 +- jwst/wfss_contam/tests/test_observations.py | 8 +- jwst/wfss_contam/wfss_contam.py | 56 ++-- jwst/wfss_contam/wfss_contam_step.py | 38 +-- 9 files changed, 290 insertions(+), 191 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 50faa43aca..a7dd5fc31c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -91,7 +91,6 @@ repos: jwst/tso_photometry/.* | jwst/wavecorr/.* | jwst/wfs_combine/.* | - jwst/wfss_contam/.* | jwst/white_light/.* | jwst/conftest.py | .*/tests/.* | diff --git a/.ruff.toml b/.ruff.toml index 70856c1166..f431bd3ddc 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -87,7 +87,7 @@ exclude = [ # "jwst/tweakreg/**.py", "jwst/wavecorr/**.py", "jwst/wfs_combine/**.py", - "jwst/wfss_contam/**.py", + # "jwst/wfss_contam/**.py", "jwst/white_light/**.py", ] @@ -217,5 +217,5 @@ ignore-fully-untyped = true # Turn of annotation checking for fully untyped cod # "jwst/tweakreg/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/wavecorr/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/wfs_combine/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] -"jwst/wfss_contam/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] +# 
"jwst/wfss_contam/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/white_light/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] diff --git a/jwst/wfss_contam/__init__.py b/jwst/wfss_contam/__init__.py index 9b0cbf422f..dd2ac2fdc9 100644 --- a/jwst/wfss_contam/__init__.py +++ b/jwst/wfss_contam/__init__.py @@ -1,3 +1,5 @@ +"""Decontaminate WFSS data.""" + from .wfss_contam_step import WfssContamStep __all__ = ["WfssContamStep"] diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index ef5476215a..04f9c65da0 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -7,13 +7,28 @@ from .sens1d import create_1d_sens -def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, - oversample_factor=2, extrapolate_sed=False, xoffset=0, - yoffset=0): +def dispersed_pixel( + x0, + y0, + width, + height, + lams, + flxs, + order, + wmin, + wmax, + sens_waves, + sens_resp, + seg_wcs, + grism_wcs, + naxis, + oversample_factor=2, + extrapolate_sed=False, + xoffset=0, + yoffset=0, +): """ - This function take a list of pixels and disperses them using the information contained - in the grism image WCS object and returns a list of dispersed pixels and fluxes. + Transform pixels from direct image to dispersed frame using the grism image WCS object. Parameters ---------- @@ -47,22 +62,20 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, The WCS object of the segmentation map. grism_wcs : WCS object The WCS object of the grism image. - ID : int - The ID of the object to which the pixel belongs. naxis : tuple Dimensions (shape) of grism image into which pixels are dispersed. 
oversample_factor : int The amount of oversampling required above that of the input spectra or natural dispersion, whichever is smaller. Default=2. extrapolate_sed : bool - Whether to allow for the SED of the object to be extrapolated when it does not fully cover the - needed wavelength range. Default if False. + Whether to allow for the SED of the object to be extrapolated when + it does not fully cover the needed wavelength range. Default is False. xoffset : int - Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to - full frame) + Pixel offset to apply when computing the dispersion (accounts for offset + from source cutout to full frame) yoffset : int - Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to - full frame) + Pixel offset to apply when computing the dispersion (accounts for offset + from source cutout to full frame) Returns ------- @@ -71,18 +84,16 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, ys : array 1D array of dispersed pixel y-coordinates areas : array - 1D array of the areas of the incident pixel that when dispersed falls on each dispersed pixel + 1D array of the areas of the incident pixel that, + when dispersed, falls on each dispersed pixel lams : array 1D array of the wavelengths of each dispersed pixel counts : array 1D array of counts for each dispersed pixel - ID : int - The source ID. Returned for bookkeeping convenience. 
""" - # Setup the transforms we need from the input WCS objects - sky_to_imgxy = grism_wcs.get_transform('world', 'detector') - imgxy_to_grismxy = grism_wcs.get_transform('detector', 'grism_detector') + sky_to_imgxy = grism_wcs.get_transform("world", "detector") + imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector") # Setup function for retrieving flux values at each dispersed wavelength if len(lams) > 1: @@ -90,13 +101,13 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, # we have the option to extrapolate the fluxes outside the # wavelength range of the direct images if extrapolate_sed is False: - flux = interp1d(lams, flxs, fill_value=0., bounds_error=False) + flux = interp1d(lams, flxs, fill_value=0.0, bounds_error=False) else: flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) else: # If we only have flux from one lambda, just use that # single flux value at all wavelengths - def flux(x): + def flux(_x): return flxs[0] # Get x/y positions in the grism image corresponding to wmin and wmax: @@ -144,12 +155,7 @@ def flux(x): # Compute arrays of dispersed pixel locations and areas padding = 1 - xs, ys, areas, index = get_clipped_pixels( - x0s, y0s, - padding, - naxis[0], naxis[1], - width, height - ) + xs, ys, areas, index = get_clipped_pixels(x0s, y0s, padding, naxis[0], naxis[1], width, height) lams = np.take(lambdas, index) # If results give no dispersed pixels, return null result @@ -166,6 +172,6 @@ def flux(x): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") counts = flux(lams) * areas / (sens * oversample_factor) - counts[no_cal] = 0. 
# set to zero where no flux cal info available + counts[no_cal] = 0.0 # set to zero where no flux cal info available - return xs, ys, areas, lams, counts, ID + return xs, ys, areas, lams, counts diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index f11108b0b9..043454d0ce 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -17,9 +17,11 @@ log.setLevel(logging.DEBUG) -def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclude_percentile=30.0): +def background_subtract( + data, box_size=None, filter_size=(3, 3), sigma=3.0, exclude_percentile=30.0 +): """ - Simple astropy background subtraction + Apply a simple astropy background subtraction. Parameters ---------- @@ -47,25 +49,42 @@ def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclu in a previous version. """ if box_size is None: - box_size = (int(data.shape[0]/5), int(data.shape[1]/5)) + box_size = (int(data.shape[0] / 5), int(data.shape[1] / 5)) sigma_clip = SigmaClip(sigma=sigma) bkg_estimator = MedianBackground() - bkg = Background2D(data, box_size, filter_size=filter_size, - sigma_clip=sigma_clip, bkg_estimator=bkg_estimator, - exclude_percentile=exclude_percentile) + bkg = Background2D( + data, + box_size, + filter_size=filter_size, + sigma_clip=sigma_clip, + bkg_estimator=bkg_estimator, + exclude_percentile=exclude_percentile, + ) return data - bkg.background class Observation: - """This class defines an actual observation. 
It is tied to a single grism image.""" - - def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, - sed_file=None, extrapolate_sed=False, - boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1): - + """Define an observation leading to a single grism image.""" + + def __init__( + self, + direct_images, + segmap_model, + grism_wcs, + filter_name, + source_id=0, + sed_file=None, + extrapolate_sed=False, + boundaries=None, + offsets=None, + renormalize=True, + max_cpu=1, + ): """ - Initialize all data and metadata for a given observation. Creates lists of + Initialize all data and metadata for a given observation. + + Creates lists of direct image pixel values for selected objects. Parameters @@ -76,9 +95,9 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, Segmentation map model grism_wcs : gwcs object WCS object from grism image - filter : str + filter_name : str Filter name - ID : int + source_id : int ID of source to process. If zero, all sources processed. sed_file : str Name of Spectral Energy Distribution (SED) file containing datasets matching @@ -87,21 +106,26 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, Flag indicating whether to extrapolate wavelength range of SED boundaries : tuple Start/Stop coordinates of the FOV within the larger seed image. 
+ offsets : tuple + Offset values for x and y axes renormalize : bool Flag indicating whether to renormalize SED's max_cpu : int Max number of cpu's to use when multiprocessing """ - + if boundaries is None: + boundaries = () + if offsets is None: + offsets = (0, 0) # Load all the info for this grism mode self.seg_wcs = segmap_model.meta.wcs self.grism_wcs = grism_wcs - self.ID = ID - self.IDs = [] + self.source_id = source_id + self.source_ids = [] self.dir_image_names = direct_images self.seg = segmap_model.data - self.filter = filter - self.sed_file = sed_file # should always be NONE for baseline pipeline (use flat SED) + self.filter = filter_name + self.sed_file = sed_file # should always be NONE for baseline pipeline (use flat SED) self.cache = False self.renormalize = renormalize self.max_cpu = max_cpu @@ -129,37 +153,35 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, self.create_pixel_list() def create_pixel_list(self): - # Create a list of pixels to be dispersed, grouped per object ID. - - if self.ID == 0: - # When ID=0, all sources in the segmentation map are processed. + """Create a list of pixels to be dispersed, grouped per object ID.""" + if self.source_id == 0: + # When source_id=0, all sources in the segmentation map are processed. # This creates a huge list of all x,y pixel indices that have non-zero values # in the seg map, sorted by those indices belonging to a particular source ID. 
self.xs = [] self.ys = [] - all_IDs = np.array(list(set(np.ravel(self.seg)))) - all_IDs = all_IDs[all_IDs > 0] - self.IDs = all_IDs - log.info(f"Loading {len(all_IDs)} sources from segmentation map") - for ID in all_IDs: - ys, xs = np.nonzero(self.seg == ID) + all_ids = np.array(list(set(np.ravel(self.seg)))) + all_ids = all_ids[all_ids > 0] + self.source_ids = all_ids + log.info(f"Loading {len(all_ids)} sources from segmentation map") + for source_id in all_ids: + ys, xs = np.nonzero(self.seg == source_id) if len(xs) > 0 and len(ys) > 0: self.xs.append(xs) self.ys.append(ys) else: # Process only the given source ID - log.info(f"Loading source {self.ID} from segmentation map") - ys, xs = np.nonzero(self.seg == self.ID) + log.info(f"Loading source {self.source_id} from segmentation map") + ys, xs = np.nonzero(self.seg == self.source_id) if len(xs) > 0 and len(ys) > 0: self.xs = [xs] self.ys = [ys] - self.IDs = [self.ID] + self.source_ids = [self.source_id] # Populate lists of direct image flux values for the sources. self.fluxes = {} for dir_image_name in self.dir_image_names: - log.info(f"Using direct image {dir_image_name}") with datamodels.open(dir_image_name) as model: dimage = model.data @@ -169,19 +191,19 @@ def create_pixel_list(self): # Default pipeline will use sed_file=None, so we need to compute # photometry values that used to come from HST-style header keywords. # Set pivlam, in units of microns, based on filter name. - pivlam = float(self.filter[1:4]) / 100. + pivlam = float(self.filter[1:4]) / 100.0 # Use pixel fluxes from the direct image. self.fluxes[pivlam] = [] - for i in range(len(self.IDs)): + for i in range(len(self.source_ids)): # This loads lists of pixel flux values for each source # from the direct image self.fluxes[pivlam].append(dimage[self.ys[i], self.xs[i]]) else: # Use an SED file. Need to normalize the object stamps. 
- for ID in self.IDs: - vg = self.seg == ID + for source_id in self.source_ids: + vg = self.seg == source_id dnew = dimage if self.renormalize: sum_seg = np.sum(dimage[vg]) # But normalize by the whole flux @@ -191,13 +213,12 @@ def create_pixel_list(self): log.debug("not renormalizing sources to unity") self.fluxes["sed"] = [] - for i in range(len(self.IDs)): + for i in range(len(self.source_ids)): self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]]) def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): """ - Compute dispersed pixel values for all sources identified in - the segmentation map. + Compute dispersed pixel values for all sources identified in the segmentation map. Parameters ---------- @@ -220,25 +241,24 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): # Initialize the simulated dispersed image self.simulated_image = np.zeros(self.dims, float) - # Loop over all source ID's from segmentation map - for i in range(len(self.IDs)): + # Loop over all source IDs from segmentation map + for i in range(len(self.source_ids)): if self.cache: self.cached_object[i] = {} - self.cached_object[i]['x'] = [] - self.cached_object[i]['y'] = [] - self.cached_object[i]['f'] = [] - self.cached_object[i]['w'] = [] - self.cached_object[i]['minx'] = [] - self.cached_object[i]['maxx'] = [] - self.cached_object[i]['miny'] = [] - self.cached_object[i]['maxy'] = [] + self.cached_object[i]["x"] = [] + self.cached_object[i]["y"] = [] + self.cached_object[i]["f"] = [] + self.cached_object[i]["w"] = [] + self.cached_object[i]["minx"] = [] + self.cached_object[i]["maxx"] = [] + self.cached_object[i]["miny"] = [] + self.cached_object[i]["maxy"] = [] self.disperse_chunk(i, order, wmin, wmax, sens_waves, sens_resp) def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): """ - Method that computes dispersion for a single source. - To be called after create_pixel_list(). 
+ Compute dispersion for a single source; to be called after create_pixel_list(). Parameters ---------- @@ -254,9 +274,13 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): Wavelength array from photom reference file sens_resp : float array Response (flux calibration) array from photom reference file - """ - sid = int(self.IDs[c]) + Returns + ------- + np.ndarray + 2D dispersed image for this source + """ + sid = int(self.source_ids[c]) self.order = order self.wmin = wmin self.wmax = wmax @@ -268,10 +292,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # Loop over all pixels in list for object "c" log.debug(f"source contains {len(self.xs[c])} pixels") for i in range(len(self.xs[c])): - - # Here "i" and "ID" are just indexes into the pixel list for the object - # being processed, as opposed to the ID number of the object itself - ID = i + # Here "i" just indexes the pixel list for the object being processed # xc, yc are the coordinates of the central pixel of the group # of pixels surrounding the direct image pixel index @@ -288,15 +309,39 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # "fluxes" is the array of pixel values from the direct image(s). # For the simple case of 1 combined direct image, this contains a # a single value (just like "lams"). 
- fluxes, lams = map(np.array, zip(*[ - (self.fluxes[lm][c][i], lm) for lm in sorted(self.fluxes.keys()) - if self.fluxes[lm][c][i] != 0 - ])) - - pars_i = (xc, yc, width, height, lams, fluxes, self.order, - self.wmin, self.wmax, self.sens_waves, self.sens_resp, - self.seg_wcs, self.grism_wcs, ID, self.dims[::-1], 2, - self.extrapolate_sed, self.xoffset, self.yoffset) + fluxes, lams = map( + np.array, + zip( + *[ + (self.fluxes[lm][c][i], lm) + for lm in sorted(self.fluxes.keys()) + if self.fluxes[lm][c][i] != 0 + ], + strict=True, + ), + ) + + pars_i = ( + xc, + yc, + width, + height, + lams, + fluxes, + self.order, + self.wmin, + self.wmax, + self.sens_waves, + self.sens_resp, + self.seg_wcs, + self.grism_wcs, + 0, + self.dims[::-1], + 2, + self.extrapolate_sed, + self.xoffset, + self.yoffset, + ) pars.append(pars_i) # now have full pars list for all pixels for this object @@ -331,42 +376,79 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): maxx = int(max(x)) miny = int(min(y)) maxy = int(max(y)) - a = sparse.coo_matrix((f, (y - miny, x - minx)), - shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + a = sparse.coo_matrix( + (f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1) + ).toarray() # Accumulate results into simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a - this_object[miny:maxy + 1, minx:maxx + 1] += a + self.simulated_image[miny : maxy + 1, minx : maxx + 1] += a + this_object[miny : maxy + 1, minx : maxx + 1] += a if self.cache: - self.cached_object[c]['x'].append(x) - self.cached_object[c]['y'].append(y) - self.cached_object[c]['f'].append(f) - self.cached_object[c]['w'].append(w) - self.cached_object[c]['minx'].append(minx) - self.cached_object[c]['maxx'].append(maxx) - self.cached_object[c]['miny'].append(miny) - self.cached_object[c]['maxy'].append(maxy) + self.cached_object[c]["x"].append(x) + self.cached_object[c]["y"].append(y) + self.cached_object[c]["f"].append(f) + 
self.cached_object[c]["w"].append(w) + self.cached_object[c]["minx"].append(minx) + self.cached_object[c]["maxx"].append(maxx) + self.cached_object[c]["miny"].append(miny) + self.cached_object[c]["maxy"].append(maxy) time2 = time.time() - log.debug(f"Elapsed time {time2-time1} sec") + log.debug(f"Elapsed time {time2 - time1} sec") return this_object def disperse_all_from_cache(self, trans=None): + """ + Compute dispersed pixel values for all sources identified in the segmentation map. + + Load data from cache where available. Currently not used. + + Parameters + ---------- + trans : function + Transmission function to apply to the flux values + + Returns + ------- + np.ndarray + 2D dispersed image for this source + + Notes + ----- + The return value of `this_object` appears to be a bug. + However, this is currently not used, and if the INS team wants to re-enable + caching, all functions here need updating anyway, so not fixing at this time. + """ if not self.cache: return self.simulated_image = np.zeros(self.dims, float) - for i in range(len(self.IDs)): + for i in range(len(self.source_ids)): this_object = self.disperse_chunk_from_cache(i, trans=trans) return this_object def disperse_chunk_from_cache(self, c, trans=None): - """Method that handles the dispersion. To be called after create_pixel_list()""" + """ + Compute dispersion for a single source; to be called after create_pixel_list(). + + Load data from cache where available. Currently not used. 
+ Parameters + ---------- + c : int + Chunk (source) number to process + trans : function + Transmission function to apply to the flux values + + Returns + ------- + np.ndarray + 2D dispersed image for this source + """ if not self.cache: return @@ -378,28 +460,29 @@ def disperse_chunk_from_cache(self, c, trans=None): if trans is not None: log.debug("Applying a transmission function...") - for i in range(len(self.cached_object[c]['x'])): - x = self.cached_object[c]['x'][i] - y = self.cached_object[c]['y'][i] - f = self.cached_object[c]['f'][i] * 1. - w = self.cached_object[c]['w'][i] + for i in range(len(self.cached_object[c]["x"])): + x = self.cached_object[c]["x"][i] + y = self.cached_object[c]["y"][i] + f = self.cached_object[c]["f"][i] * 1.0 + w = self.cached_object[c]["w"][i] if trans is not None: f *= trans(w) - minx = self.cached_object[c]['minx'][i] - maxx = self.cached_object[c]['maxx'][i] - miny = self.cached_object[c]['miny'][i] - maxy = self.cached_object[c]['maxy'][i] + minx = self.cached_object[c]["minx"][i] + maxx = self.cached_object[c]["maxx"][i] + miny = self.cached_object[c]["miny"][i] + maxy = self.cached_object[c]["maxy"][i] - a = sparse.coo_matrix((f, (y - miny, x - minx)), - shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + a = sparse.coo_matrix( + (f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1) + ).toarray() # Accumulate the results into the simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a - this_object[miny:maxy + 1, minx:maxx + 1] += a + self.simulated_image[miny : maxy + 1, minx : maxx + 1] += a + this_object[miny : maxy + 1, minx : maxx + 1] += a time2 = time.time() - log.debug(f"Elapsed time {time2-time1} sec") + log.debug(f"Elapsed time {time2 - time1} sec") return this_object diff --git a/jwst/wfss_contam/sens1d.py b/jwst/wfss_contam/sens1d.py index a5544f9e2c..f78f507a65 100644 --- a/jwst/wfss_contam/sens1d.py +++ b/jwst/wfss_contam/sens1d.py @@ -3,20 +3,20 @@ from jwst.photom.photom 
import find_row import logging + log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -def get_photom_data(phot_model, filter, pupil, order): +def get_photom_data(phot_model, filter_name, pupil, order): """ - Retrieves wavelength and response data from photom ref file - for the filter+pupil (grism) mode in use. + Retrieve wavelength and response data from photom ref file. Parameters ---------- phot_model : `jwst.datamodels.NrcWfssPhotomModel` or `jwst.datamodels.NisWfssPhotomModel` Photom ref file data model - filter : str + filter_name : str Filter value pupil : str Pupil value @@ -30,22 +30,21 @@ def get_photom_data(phot_model, filter, pupil, order): relresps : float array Array of response (flux calibration) values from the ref file """ - # Get the appropriate row of data from the reference table phot_table = phot_model.phot_table - fields_to_match = {'filter': filter, 'pupil': pupil, 'order': order} + fields_to_match = {"filter": filter_name, "pupil": pupil, "order": order} row = find_row(phot_table, fields_to_match) tabdata = phot_table[row] # Scalar conversion factor - scalar_conversion = tabdata['photmjsr'] # unit is MJy / sr + scalar_conversion = tabdata["photmjsr"] # unit is MJy / sr # Get the length of the relative response arrays in this row - nelem = tabdata['nelem'] + nelem = tabdata["nelem"] # Load the wavelength and relative response arrays - ref_waves = tabdata['wavelength'][:nelem] - relresps = scalar_conversion * tabdata['relresponse'][:nelem] + ref_waves = tabdata["wavelength"][:nelem] + relresps = scalar_conversion * tabdata["relresponse"][:nelem] # Make sure waves and relresps are in increasing wavelength order if not np.all(np.diff(ref_waves) > 0): @@ -54,17 +53,16 @@ def get_photom_data(phot_model, filter, pupil, order): relresps = relresps[index].copy() # Convert wavelengths from meters to microns, if necessary - microns_100 = 1.e-4 # 100 microns, in meters - if ref_waves.max() > 0. 
and ref_waves.max() < microns_100: - ref_waves *= 1.e+6 + microns_100 = 1.0e-4 # 100 microns, in meters + if ref_waves.max() > 0.0 and ref_waves.max() < microns_100: + ref_waves *= 1.0e6 return ref_waves, relresps def create_1d_sens(data_waves, ref_waves, relresps): """ - Create a 1D array of photometric conversion values based on - wavelengths per pixel and response as a function of wavelength. + Find photometric conversion values based on per-pixel wavelength-dependent response. Parameters ---------- @@ -82,7 +80,6 @@ def create_1d_sens(data_waves, ref_waves, relresps): no_cal : int array 1D mask indicating where no conversion is available """ - # Interpolate the photometric response values onto the # 1D wavelength grid of the data sens_1d = np.interp(data_waves, ref_waves, relresps, left=np.nan, right=np.nan) diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index 1cbbad0bc3..47b64969bc 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -97,15 +97,15 @@ def test_disperse_oversample_same_result(grism_wcs, segmentation_map): yoffset = 1000 - xs, ys, areas, lams_out, counts_1, ID = dispersed_pixel( + xs, ys, areas, lams_out, counts_1 = dispersed_pixel( x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + sens_waves, sens_resp, seg_wcs, grism_wcs, naxis, oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) - xs, ys, areas, lams_out, counts_3, ID = dispersed_pixel( + xs, ys, areas, lams_out, counts_3 = dispersed_pixel( x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + sens_waves, sens_resp, seg_wcs, grism_wcs, naxis, oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 841538442f..6fae4b53e9 100644 --- 
a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,6 +1,5 @@ -# -# Top level module for WFSS contamination correction. -# +"""Top-level module for WFSS contamination correction.""" + import logging import multiprocessing import numpy as np @@ -16,7 +15,7 @@ def contam_corr(input_model, waverange, photom, max_cores): """ - The main WFSS contamination correction function + Correct contamination in WFSS spectral cutouts. Parameters ---------- @@ -26,7 +25,7 @@ def contam_corr(input_model, waverange, photom, max_cores): Wavelength range reference file model photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel` Photom (flux cal) reference file model - max_cores : string + max_cores : str Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other allowable values are 'quarter', 'half', and 'all', which indicate @@ -41,18 +40,17 @@ def contam_corr(input_model, waverange, photom, max_cores): Full-frame simulated image of the grism exposure contam_model : `~jwst.datamodels.MultiSlitModel` Contamination estimate images for each source slit - """ # Determine number of cpu's to use for multi-processing - if max_cores == 'none': + if max_cores == "none": ncpus = 1 else: num_cores = multiprocessing.cpu_count() - if max_cores == 'quarter': + if max_cores == "quarter": ncpus = num_cores // 4 or 1 - elif max_cores == 'half': + elif max_cores == "half": ncpus = num_cores // 2 or 1 - elif max_cores == 'all': + elif max_cores == "all": ncpus = num_cores else: ncpus = 1 @@ -93,7 +91,7 @@ def contam_corr(input_model, waverange, photom, max_cores): # the opposite. It has gratings in the FILTER wheel and filters in the # PUPIL wheel. So when processing NIRISS grism exposures the name of # filter needs to come from the PUPIL keyword value. 
- if input_model.meta.instrument.name == 'NIRISS': + if input_model.meta.instrument.name == "NIRISS": filter_name = pupil_kwd else: filter_name = filter_kwd @@ -108,20 +106,27 @@ def contam_corr(input_model, waverange, photom, max_cores): wmin[order] = wavelength_range[order][0] wmax[order] = wavelength_range[order][1] # Load the sensitivity (inverse flux cal) data for this mode and order - sens_waves[order], sens_response[order] = get_photom_data(photom, filter_kwd, pupil_kwd, order) + sens_waves[order], sens_response[order] = get_photom_data( + photom, filter_kwd, pupil_kwd, order + ) log.debug(f"wmin={wmin}, wmax={wmax}") # Initialize the simulated image object simul_all = None - obs = Observation(image_names, seg_model, grism_wcs, filter_name, - boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus) + obs = Observation( + image_names, + seg_model, + grism_wcs, + filter_name, + boundaries=[0, 2047, 0, 2047], + offsets=[xoffset, yoffset], + max_cpu=ncpus, + ) # Create simulated grism image for each order and sum them up for order in spec_orders: - log.info(f"Creating full simulated grism image for order {order}") - obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order], - sens_response[order]) + obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order], sens_response[order]) # Accumulate result for this order into the combined image if simul_all is None: @@ -139,15 +144,15 @@ def contam_corr(input_model, waverange, photom, max_cores): contam_model.update(input_model) slits = [] for slit in output_model.slits: - # Create simulated spectrum for this source only sid = slit.source_id order = slit.meta.wcsinfo.spectral_order - chunk = np.where(obs.IDs == sid)[0][0] # find chunk for this source + chunk = np.where(obs.source_ids == sid)[0][0] # find chunk for this source obs.simulated_image = np.zeros(obs.dims) - obs.disperse_chunk(chunk, order, wmin[order], wmax[order], - sens_waves[order], sens_response[order]) + 
obs.disperse_chunk( + chunk, order, wmin[order], wmax[order], sens_waves[order], sens_response[order] + ) this_source = obs.simulated_image # Contamination estimate is full simulated image minus this source @@ -157,7 +162,7 @@ def contam_corr(input_model, waverange, photom, max_cores): # of the source slit x1 = slit.xstart - 1 y1 = slit.ystart - 1 - cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] + cutout = contam[y1 : y1 + slit.ysize, x1 : x1 + slit.xsize] new_slit = datamodels.SlitModel(data=cutout) copy_slit_info(slit, new_slit) slits.append(new_slit) @@ -169,14 +174,14 @@ def contam_corr(input_model, waverange, photom, max_cores): contam_model.slits.extend(slits) # Set the step status to COMPLETE - output_model.meta.cal_step.wfss_contam = 'COMPLETE' + output_model.meta.cal_step.wfss_contam = "COMPLETE" return output_model, simul_model, contam_model def copy_slit_info(input_slit, output_slit): - - """Copy meta info from one slit to another. + """ + Copy meta info from one slit to another. Parameters ---------- @@ -185,7 +190,6 @@ def copy_slit_info(input_slit, output_slit): output_slit : SlitModel Output slit model to which slit-specific info will be copied - """ output_slit.name = input_slit.name output_slit.xstart = input_slit.xstart diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index 66b34ebf5c..7f925ea134 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -9,9 +9,7 @@ class WfssContamStep(Step): - """ - This Step performs contamination correction of WFSS spectra. 
- """ + """Perform contamination correction of WFSS spectra.""" class_alias = "wfss_contam" @@ -20,30 +18,40 @@ class WfssContamStep(Step): save_contam_images = boolean(default=False) # Save source contam estimates maximum_cores = option('none', 'quarter', 'half', 'all', default='none') skip = boolean(default=True) - """ # noqa: E501 + """ # noqa: E501 - reference_file_types = ['photom', 'wavelengthrange'] + reference_file_types = ["photom", "wavelengthrange"] - def process(self, input_model, *args, **kwargs): + def process(self, input_model): + """ + Run the WFSS contamination correction step. - with datamodels.open(input_model) as dm: + Parameters + ---------- + input_model : `~jwst.datamodels.MultiSlitModel` + The input data model containing 2-D cutouts for each identified source. + Returns + ------- + output_model : `~jwst.datamodels.MultiSlitModel` + A copy of the input_model with contamination removed + """ + with datamodels.open(input_model) as dm: max_cores = self.maximum_cores # Get the wavelengthrange ref file - waverange_ref = self.get_reference_file(dm, 'wavelengthrange') - self.log.info(f'Using WAVELENGTHRANGE reference file {waverange_ref}') + waverange_ref = self.get_reference_file(dm, "wavelengthrange") + self.log.info(f"Using WAVELENGTHRANGE reference file {waverange_ref}") waverange_model = datamodels.WavelengthrangeModel(waverange_ref) # Get the photom ref file - photom_ref = self.get_reference_file(dm, 'photom') - self.log.info(f'Using PHOTOM reference file {photom_ref}') + photom_ref = self.get_reference_file(dm, "photom") + self.log.info(f"Using PHOTOM reference file {photom_ref}") photom_model = datamodels.open(photom_ref) - result, simul, contam = wfss_contam.contam_corr(dm, - waverange_model, - photom_model, - max_cores) + result, simul, contam = wfss_contam.contam_corr( + dm, waverange_model, photom_model, max_cores + ) # Save intermediate results, if requested if self.save_simulated_image: From ecebb77e6545ff55d807b1b92c7a16d423fb829e Mon 
Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Feb 2025 09:50:02 -0500 Subject: [PATCH 2/4] revert removal of id from dispersed_pixel --- jwst/regtest/test_nircam_wfss_contam.py | 2 +- jwst/wfss_contam/disperse.py | 7 ++++++- jwst/wfss_contam/observations.py | 2 +- jwst/wfss_contam/tests/test_observations.py | 10 +++++----- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/jwst/regtest/test_nircam_wfss_contam.py b/jwst/regtest/test_nircam_wfss_contam.py index 0b50b4935f..93dc5d8bd9 100644 --- a/jwst/regtest/test_nircam_wfss_contam.py +++ b/jwst/regtest/test_nircam_wfss_contam.py @@ -27,7 +27,7 @@ def run_wfss_contam(rtdata_module): rtdata = rt.run_step_from_dict(rtdata, **step_params) return rtdata -@pytest.mark.skip(reason='Test too slow until stdatamodels PR#165 merged') +#@pytest.mark.skip(reason='Test too slow until stdatamodels PR#165 merged') @pytest.mark.bigdata @pytest.mark.parametrize( 'suffix', diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index 04f9c65da0..72e5ad7770 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -21,6 +21,7 @@ def dispersed_pixel( sens_resp, seg_wcs, grism_wcs, + source_id, naxis, oversample_factor=2, extrapolate_sed=False, @@ -62,6 +63,10 @@ def dispersed_pixel( The WCS object of the segmentation map. grism_wcs : WCS object The WCS object of the grism image. + source_id : int + The source ID of the source being processed. Returned in the output unmodified; + used only for bookkeeping. TODO this is not implemented properly right now and + should probably just be removed. naxis : tuple Dimensions (shape) of grism image into which pixels are dispersed. 
oversample_factor : int @@ -174,4 +179,4 @@ def flux(_x): counts = flux(lams) * areas / (sens * oversample_factor) counts[no_cal] = 0.0 # set to zero where no flux cal info available - return xs, ys, areas, lams, counts + return xs, ys, areas, lams, counts, source_id diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 043454d0ce..c68e2c496e 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -335,7 +335,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): self.sens_resp, self.seg_wcs, self.grism_wcs, - 0, + i, # TODO: this is not the source_id as the docstring to dispersed_pixel says self.dims[::-1], 2, self.extrapolate_sed, diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index 47b64969bc..f4ee4dc4e0 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -86,7 +86,7 @@ def test_disperse_oversample_same_result(grism_wcs, segmentation_map): height = 1.0 lams = [2.0] flxs = [1.0] - ID = 0 + source_id = 0 naxis = (300, 500) sens_waves = np.linspace(1.708, 2.28, 100) wmin, wmax = np.min(sens_waves), np.max(sens_waves) @@ -97,15 +97,15 @@ def test_disperse_oversample_same_result(grism_wcs, segmentation_map): yoffset = 1000 - xs, ys, areas, lams_out, counts_1 = dispersed_pixel( + xs, ys, areas, lams_out, counts_1, source_id = dispersed_pixel( x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, naxis, + sens_waves, sens_resp, seg_wcs, grism_wcs, source_id, naxis, oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) - xs, ys, areas, lams_out, counts_3 = dispersed_pixel( + xs, ys, areas, lams_out, counts_3, source_id = dispersed_pixel( x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, naxis, + sens_waves, sens_resp, seg_wcs, grism_wcs, source_id, naxis, 
oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) From 2b09c4116ab736ea65574b88c76a530f59eb9fb8 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Feb 2025 11:07:35 -0500 Subject: [PATCH 3/4] re-ignore regtest --- jwst/regtest/test_nircam_wfss_contam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jwst/regtest/test_nircam_wfss_contam.py b/jwst/regtest/test_nircam_wfss_contam.py index 93dc5d8bd9..0b50b4935f 100644 --- a/jwst/regtest/test_nircam_wfss_contam.py +++ b/jwst/regtest/test_nircam_wfss_contam.py @@ -27,7 +27,7 @@ def run_wfss_contam(rtdata_module): rtdata = rt.run_step_from_dict(rtdata, **step_params) return rtdata -#@pytest.mark.skip(reason='Test too slow until stdatamodels PR#165 merged') +@pytest.mark.skip(reason='Test too slow until stdatamodels PR#165 merged') @pytest.mark.bigdata @pytest.mark.parametrize( 'suffix', From b3d1992df91d9827bf8e09c9a7824bcf0948c7c1 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Feb 2025 12:17:47 -0500 Subject: [PATCH 4/4] Update optional args to Observation init --- jwst/wfss_contam/observations.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index c68e2c496e..8b975b4c85 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -97,26 +97,26 @@ def __init__( WCS object from grism image filter_name : str Filter name - source_id : int - ID of source to process. If zero, all sources processed. - sed_file : str + source_id : int, optional, default 0 + ID of source to process. If 0, all sources processed. + sed_file : str, optional, default None Name of Spectral Energy Distribution (SED) file containing datasets matching the ID in the segmentation file and each consisting of a [[lambda],[flux]] array. 
- extrapolate_sed : bool + extrapolate_sed : bool, optional, default False Flag indicating whether to extrapolate wavelength range of SED - boundaries : tuple + boundaries : list, optional, default [] Start/Stop coordinates of the FOV within the larger seed image. - offsets : tuple + offsets : list, optional, default [0,0] Offset values for x and y axes - renormalize : bool + renormalize : bool, optional, default True Flag indicating whether to renormalize SED's - max_cpu : int + max_cpu : int, optional, default 1 Max number of cpu's to use when multiprocessing """ if boundaries is None: - boundaries = () + boundaries = [] if offsets is None: - offsets = (0, 0) + offsets = [0, 0] # Load all the info for this grism mode self.seg_wcs = segmap_model.meta.wcs self.grism_wcs = grism_wcs