"""Utility functions for applying flat field corrections."""
import logging
import math
import warnings
import numpy as np
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags
from jwst.assign_wcs import nirspec
from jwst.lib import pipe_utils, reffile_utils, wcs_utils
# Module-level logger; every function in this file reports through it.
log = logging.getLogger(__name__)

# Threshold for sanity-checking wavelength units. Reference-file wavelengths
# whose maximum falls below this value trigger an "appear to be in meters"
# warning in the flat-extraction functions.
MICRONS_100 = 1.0e-4 # 100 microns, in meters

# NIRSpec exposure type categories
FIXED_SLIT_TYPES = ["NRS_LAMP", "NRS_AUTOWAVE", "NRS_BRIGHTOBJ", "NRS_FIXEDSLIT"]

# Exposure types that do_correction() routes to the NIRSpec spectroscopic
# flat-field code, unless the lamp mode is "IMAGE".
NIRSPEC_SPECTRAL_EXPOSURES = [
    "NRS_AUTOWAVE",
    "NRS_BRIGHTOBJ",
    "NRS_FIXEDSLIT",
    "NRS_IFU",
    "NRS_LAMP",
    "NRS_MSASPEC",
]

# Dispersion direction, predominantly horizontal or vertical. These values
# are to be compared with keyword DISPAXIS from the input header.
HORIZONTAL = 1
VERTICAL = 2

# Bitwise OR of the DQ flags NO_FLAT_FIELD, DO_NOT_USE and UNRELIABLE_FLAT.
BADFLAT = (
    dqflags.pixel["NO_FLAT_FIELD"] | dqflags.pixel["DO_NOT_USE"] | dqflags.pixel["UNRELIABLE_FLAT"]
)

# Names exported as the public API of this module.
__all__ = [
    "do_correction",
    "do_flat_field",
    "apply_flat_field",
    "do_nirspec_flat_field",
    "nirspec_fs_msa",
    "nirspec_brightobj",
    "nirspec_ifu",
    "create_flat_field",
    "fore_optics_flat",
    "spectrograph_flat",
    "detector_flat",
    "combine_dq",
    "read_image_wl",
    "read_flat_table",
    "combine_fast_slow",
    "clean_wl",
    "interpolate_flat",
    "flat_for_nirspec_ifu",
    "flat_for_nirspec_brightobj",
    "flat_for_nirspec_slit",
]
[docs]
def do_correction(
    input_model,
    flat=None,
    fflat=None,
    sflat=None,
    dflat=None,
    user_supplied_flat=None,
    inverse=False,
):
    """
    Apply the flat-field correction to a JWST data model.

    NIRSpec spectroscopic exposures (any type listed in
    ``NIRSPEC_SPECTRAL_EXPOSURES`` whose lamp mode is not ``"IMAGE"``) are
    handed to `do_nirspec_flat_field`. Everything else, including NIRSpec
    imaging, is corrected with a single flat image via `do_flat_field`.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data to correct. Updated in place.
    flat : `~stdatamodels.jwst.datamodels.JwstDataModel` or None, optional
        Flat-field image used for all data other than NIRSpec
        spectroscopic data.
    fflat : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
            `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
            or None, optional
        Fore-optics flat. Used only for NIRSpec spectroscopic data.
    sflat : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None, optional
        Spectrograph flat. Used only for NIRSpec spectroscopic data.
    dflat : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None, optional
        Detector flat. Used only for NIRSpec spectroscopic data.
    user_supplied_flat : `~stdatamodels.jwst.datamodels.JwstDataModel` or None, optional
        When given, this flat takes precedence over ``flat`` (and is
        forwarded to the NIRSpec code for spectroscopic data).
    inverse : bool, optional
        Invert the math operations used to apply the flat field.

    Returns
    -------
    input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        The same model that was passed in, now corrected.
    flat_applied : `~stdatamodels.jwst.datamodels.MultiSlitModel`, \
                   `~stdatamodels.jwst.datamodels.ImageModel`, or None
        For NIRSpec spectroscopic data, whatever `do_nirspec_flat_field`
        returns; otherwise the flat that was applied, or None when no
        flat was available and the step was marked ``"SKIPPED"``.
    """
    nirspec_spectral = (
        input_model.meta.exposure.type in NIRSPEC_SPECTRAL_EXPOSURES
        and input_model.meta.instrument.lamp_mode != "IMAGE"
    )
    if nirspec_spectral:
        # Composite (fore optics / spectrograph / detector) flat field.
        nirspec_flat = do_nirspec_flat_field(
            input_model,
            fflat,
            sflat,
            dflat,
            user_supplied_flat=user_supplied_flat,
            inverse=inverse,
        )
        return input_model, nirspec_flat

    # Simple image flat: a user-supplied flat overrides the reference flat.
    chosen_flat = flat if user_supplied_flat is None else user_supplied_flat
    if chosen_flat is None:
        log.warning("No flat found or supplied; step will be skipped.")
        input_model.meta.cal_step.flat_field = "SKIPPED"
        return input_model, None

    do_flat_field(input_model, chosen_flat, inverse=inverse)
    return input_model, chosen_flat
[docs]
def do_flat_field(output_model, flat_model, inverse=False):
    """
    Correct a model with a single flat-field image, in place.

    This is the path for every exposure type except NIRSpec spectroscopic
    modes, which need a composite flat. The ``flat_field`` calibration
    step status on ``output_model`` is set to ``"COMPLETE"`` when a flat
    was applied and ``"SKIPPED"`` otherwise.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model to be flat-fielded; modified in place.
    flat_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Data model containing the flat field.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.
    """
    if output_model.meta.instrument.name == "NIRSPEC":
        log.debug("Flat field correction for NIRSpec imaging data.")
    else:
        log.debug("Flat field correction for non-NIRSpec modes.")

    applied = False  # becomes True once at least one flat is applied

    sci_shape = output_model.data.shape
    ref_shape = flat_model.data.shape
    if sci_shape[-1] > ref_shape[-1] or sci_shape[-2] > ref_shape[-2]:
        # The reference image cannot cover the science array.
        log.warning("Reference data array is smaller than science data")
        log.warning("Step will be skipped")
    elif isinstance(output_model, datamodels.MultiSlitModel):
        # One correction per slit; an empty slit list leaves the step skipped.
        for slit in output_model.slits:
            log.debug("Applying flat to slit %s", slit.name)
            apply_flat_field(slit, flat_model, inverse=inverse)
            applied = True
    else:
        apply_flat_field(output_model, flat_model, inverse=inverse)
        applied = True

    output_model.meta.cal_step.flat_field = "COMPLETE" if applied else "SKIPPED"
[docs]
def apply_flat_field(science, flat, inverse=False):
    """
    Divide science data by a flat field and propagate DQ and variances.

    Flat pixels that are NaN, zero, or flagged NO_FLAT_FIELD are flagged
    DO_NOT_USE; every DO_NOT_USE flat pixel is then set to 1.0 so it
    applies no correction. With ``inverse=True`` the science data are
    multiplied by the flat instead, and ``var_flat`` is reset to zero.

    When the flat does not match the science array according to
    ``reffile_utils.ref_matches_sci``, the matching subarray is extracted
    and copied first. Otherwise the arrays returned by ``flat`` are used
    (and masked) directly.

    Parameters
    ----------
    science : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Input science data model; modified in place.
    flat : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Flat field data model.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.
    """
    do_not_use = dqflags.pixel["DO_NOT_USE"]
    no_flat_field = dqflags.pixel["NO_FLAT_FIELD"]

    # Get flat arrays that line up with the science array.
    if reffile_utils.ref_matches_sci(science, flat):
        flat_data, flat_dq, flat_err = flat.data, flat.dq, flat.err
    else:
        log.info("Extracting matching subarray from flat")
        sub_flat = reffile_utils.get_subarray_model(science, flat)
        flat_data = sub_flat.data.copy()
        flat_dq = sub_flat.dq.copy()
        flat_err = sub_flat.err.copy()
        sub_flat.close()

    # NaN and zero-valued flat pixels get DO_NOT_USE + NO_FLAT_FIELD.
    bad_flag = do_not_use + no_flat_field
    nan_pixels = np.isnan(flat_data)
    flat_dq[nan_pixels] = flat_dq[nan_pixels] | bad_flag
    zero_pixels = flat_data == 0.0
    flat_dq[zero_pixels] = flat_dq[zero_pixels] | bad_flag

    # Anything already marked NO_FLAT_FIELD must also be DO_NOT_USE.
    unflattened = (flat_dq & no_flat_field) != 0
    flat_dq[unflattened] = flat_dq[unflattened] | do_not_use

    # Neutralize every DO_NOT_USE pixel so no correction is made there.
    unusable = (flat_dq & do_not_use) != 0
    flat_data[unusable] = 1.0

    # Broadcasting takes care of science cubes.
    flat_data_squared = flat_data * flat_data

    # Guider data have not been through ramp fitting, so they carry no
    # Poisson or read-noise variance to rescale.
    has_ramp_variances = not isinstance(science, datamodels.GuiderCalModel)

    if inverse:
        science.data *= flat_data
        # Var_flat does not exist before the flat-field step: zero it.
        science.var_flat = np.zeros_like(science.data)
        if has_ramp_variances:
            science.var_poisson *= flat_data_squared
            science.var_rnoise *= flat_data_squared
            science.err = np.sqrt(science.var_poisson + science.var_rnoise)
    else:
        science.data /= flat_data
        # NOTE: Re-arranging var_flat math will cause some NIRISS SOSS
        # regression tests to fail with floating point differences.
        science.var_flat = science.data**2 / flat_data_squared * flat_err**2
        if has_ramp_variances:
            science.var_poisson /= flat_data_squared
            science.var_rnoise /= flat_data_squared
            science.err = np.sqrt(science.var_poisson + science.var_rnoise + science.var_flat)
        else:
            # Add the flat-field variance to the existing ERR in quadrature.
            science.err = np.sqrt(science.err**2 + science.var_flat)

    # Merge flat DQ into science DQ, then reconcile NaNs with the flags.
    science.dq = np.bitwise_or(science.dq, flat_dq)
    pipe_utils.match_nans_and_flags(science)
[docs]
def do_nirspec_flat_field(
    output_model, f_flat_model, s_flat_model, d_flat_model, user_supplied_flat=None, inverse=False
):
    """
    Dispatch NIRSpec spectroscopic data to the matching flat-field routine.

    The data are handled by one of three functions:

    1. `nirspec_brightobj` for ``NRS_BRIGHTOBJ`` exposures,
    2. `nirspec_fs_msa` for any model that has a ``slits`` attribute, or
    3. `nirspec_ifu` for IFU exposures (including lamp/autowave exposures
       taken in IFU lamp mode).

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model, modified (flat fielded) in-place.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    user_supplied_flat : `~stdatamodels.jwst.datamodels.JwstDataModel` or None, optional
        If provided, it is forwarded to the selected routine, which uses
        it in place of a flat built from the reference models.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.

    Returns
    -------
    result : `~stdatamodels.jwst.datamodels.MultiSlitModel` or \
             `~stdatamodels.jwst.datamodels.ImageModel`
        The interpolated flat field(s), as returned by the selected routine.

    Raises
    ------
    TypeError
        If BRIGHTOBJ data is not a SlitModel, if IFU data is not an
        IFUImageModel, or if a model without slits is not IFU data.
    """
    log.debug("Flat field correction for NIRSpec spectrographic data.")
    exposure_type = output_model.meta.exposure.type

    # Dispersion direction: model-level first, then the first slit.
    try:
        dispaxis = output_model.meta.wcsinfo.dispersion_direction
    except AttributeError:
        slits = output_model.slits
        dispaxis = slits[0].meta.wcsinfo.dispersion_direction if len(slits) > 0 else None
    if dispaxis is None:
        log.warning("Can't determine dispaxis, assuming horizontal.")
        dispaxis = HORIZONTAL

    flat_models = (f_flat_model, s_flat_model, d_flat_model)
    options = {"user_supplied_flat": user_supplied_flat, "inverse": inverse}

    if exposure_type == "NRS_BRIGHTOBJ":
        if not isinstance(output_model, datamodels.SlitModel):
            log.error("NIRSpec BRIGHTOBJ data is not a SlitModel; don't know how to process it.")
            raise TypeError(f"Input is {type(output_model)}; expected SlitModel")
        return nirspec_brightobj(output_model, *flat_models, dispaxis, **options)

    # Models with slits cover the MSA and fixed-slit modes. IFU data are
    # normally an IFUImageModel, but slices copied into a MultiSlitModel
    # are also handled here.
    if hasattr(output_model, "slits"):
        return nirspec_fs_msa(output_model, *flat_models, dispaxis, **options)

    ifu_exposure = exposure_type == "NRS_IFU" or (
        exposure_type in ["NRS_AUTOWAVE", "NRS_LAMP"]
        and output_model.meta.instrument.lamp_mode == "IFU"
    )
    if not ifu_exposure:
        raise TypeError(f"No flat field algorithm exists for handling data {output_model}")
    if not isinstance(output_model, datamodels.IFUImageModel):
        log.error("NIRSpec IFU data is not an IFUImageModel; don't know how to process it.")
        raise TypeError(f"Input is {type(output_model)}; expected IFUImageModel")
    return nirspec_ifu(output_model, *flat_models, dispaxis, **options)
[docs]
def nirspec_fs_msa(
    output_model,
    f_flat_model,
    s_flat_model,
    d_flat_model,
    dispaxis,
    user_supplied_flat=None,
    inverse=False,
):
    """
    Apply flat-fielding for NIRSpec fixed slit and MSA data, in-place.

    Each slit of ``output_model`` is corrected with its own flat. For
    fixed-slit point sources two flats are computed: one without and one
    with corrected wavelengths. Both are stored on the slit
    (``flatfield_uniform`` and ``flatfield_point``) and the point-source
    version is the one applied.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
        Science data, modified (flat fielded) slit-by-slit, in-place.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    user_supplied_flat : `~stdatamodels.jwst.datamodels.JwstDataModel` or None, optional
        If provided, its slits are used as the flats (matched to the
        science slits by position) and the reference models are ignored.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.

    Returns
    -------
    interpolated_flats : `~stdatamodels.jwst.datamodels.MultiSlitModel`
        The interpolated flat field, one for each slit. This is
        ``user_supplied_flat`` itself when one was given.
    """
    exposure_type = output_model.meta.exposure.type

    # Use one explicit None test for the user flat everywhere in this
    # function, rather than relying on the truthiness of a data model.
    have_user_flat = user_supplied_flat is not None

    # Computed flats are collected in a plain list and attached to the
    # output MultiSlitModel in one go at the end; this postpones
    # validation until the end, which is faster.
    flat_slits = []

    # Tracks whether at least one slit was flat fielded, so the step can be
    # marked "COMPLETE"; otherwise it is marked "SKIPPED".
    any_updated = False

    for slit_idx, slit in enumerate(output_model.slits):
        log.info("Working on slit %s", slit.name)

        # For MSA data the slit itself carries the quadrant/shutter info.
        slit_nt = slit if exposure_type == "NRS_MSASPEC" else None

        if have_user_flat:
            slit_flat = user_supplied_flat.slits[slit_idx]
        else:
            slit_args = (
                slit,
                f_flat_model,
                s_flat_model,
                d_flat_model,
                dispaxis,
                exposure_type,
                slit_nt,
                output_model.meta.subarray,
            )
            if exposure_type == "NRS_FIXEDSLIT" and slit.source_type.upper() == "POINT":
                # Uniform-source flat: do NOT use corrected wavelengths.
                slit_flat = flat_for_nirspec_slit(*slit_args, use_wavecorr=False)
                slit.flatfield_uniform = slit_flat.data

                # Point-source flat: use corrected wavelengths. This is
                # the version actually applied to the data below.
                slit_flat = flat_for_nirspec_slit(*slit_args, use_wavecorr=True)
                slit.flatfield_point = slit_flat.data
            else:
                # Normal case: no preference about corrected wavelengths.
                slit_flat = flat_for_nirspec_slit(*slit_args, use_wavecorr=None)

            # Only computed flats need collecting; a user-supplied model is
            # returned as-is.
            flat_slits.append(slit_flat)

        # Apply the correction to science data and error arrays, relying
        # on array broadcasting to handle cubes, and update the variances
        # using the BASELINE algorithm.
        flat_data_squared = slit_flat.data * slit_flat.data
        if not inverse:
            slit.data /= slit_flat.data
            slit.var_poisson /= flat_data_squared
            slit.var_rnoise /= flat_data_squared
            # NIRSpec flats have very small values: some variance values may overflow
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", "overflow encountered in square", RuntimeWarning)
                slit.var_flat = (slit.data / slit_flat.data * slit_flat.err) ** 2
            slit.err = np.sqrt(slit.var_poisson + slit.var_rnoise + slit.var_flat)
        else:
            slit.data *= slit_flat.data
            slit.var_poisson *= flat_data_squared
            slit.var_rnoise *= flat_data_squared
            # Var_flat does not exist before flatfield step - set it to zero.
            slit.var_flat = np.zeros_like(slit.data)
            slit.err = np.sqrt(slit.var_poisson + slit.var_rnoise)

        # Combine the science and flat DQ arrays
        slit.dq |= slit_flat.dq

        # Make sure all NaNs and flags match up in the output model
        pipe_utils.match_nans_and_flags(slit)
        any_updated = True

    output_model.meta.cal_step.flat_field = "COMPLETE" if any_updated else "SKIPPED"

    # Hand back the user's model untouched, or build a model holding the
    # flats computed above.
    if have_user_flat:
        return user_supplied_flat

    interpolated_flat = datamodels.MultiSlitModel()
    interpolated_flat.update(output_model, only="PRIMARY")
    interpolated_flat.slits.extend(flat_slits)
    return interpolated_flat
[docs]
def nirspec_brightobj(
    output_model,
    f_flat_model,
    s_flat_model,
    d_flat_model,
    dispaxis,
    user_supplied_flat=None,
    inverse=False,
):
    """
    Flat-field NIRSpec BRIGHTOBJ data, in-place.

    The flat is either the user-supplied one or the result of
    `flat_for_nirspec_brightobj`. Data, variances, ERR and DQ of
    ``output_model`` are updated and the step is marked ``"COMPLETE"``.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model, modified (flat fielded) in-place.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    user_supplied_flat : `~stdatamodels.jwst.datamodels.ImageModel` or None, optional
        A pre-computed flat to use directly. If supplied, the reference
        flat models are ignored.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.

    Returns
    -------
    result : `~stdatamodels.jwst.datamodels.ImageModel`
        The interpolated flat field that was applied.
    """
    if user_supplied_flat is None:
        interpolated_flat = flat_for_nirspec_brightobj(
            output_model, f_flat_model, s_flat_model, d_flat_model, dispaxis
        )
    else:
        log.info(f"Pre-computed flat {user_supplied_flat} provided. Using the flat directly")
        interpolated_flat = user_supplied_flat

    # Variances and ERR follow the BASELINE algorithm.
    flat_data_squared = interpolated_flat.data * interpolated_flat.data

    if inverse:
        output_model.data *= interpolated_flat.data
        output_model.var_poisson *= flat_data_squared
        output_model.var_rnoise *= flat_data_squared
        # Var_flat does not exist before the flat-field step: zero it.
        output_model.var_flat = np.zeros_like(output_model.data)
        output_model.err = np.sqrt(output_model.var_poisson + output_model.var_rnoise)
    else:
        output_model.data /= interpolated_flat.data
        output_model.var_poisson /= flat_data_squared
        output_model.var_rnoise /= flat_data_squared
        # NIRSpec flats have very small values, so squaring may overflow;
        # silence just that warning.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="ignore",
                message="overflow encountered in square",
                category=RuntimeWarning,
            )
            output_model.var_flat = (
                output_model.data / interpolated_flat.data * interpolated_flat.err
            ) ** 2
        output_model.err = np.sqrt(
            output_model.var_poisson + output_model.var_rnoise + output_model.var_flat
        )

    output_model.dq |= interpolated_flat.dq

    # Reconcile NaNs and DQ flags across the output arrays.
    pipe_utils.match_nans_and_flags(output_model)
    output_model.meta.cal_step.flat_field = "COMPLETE"
    return interpolated_flat
[docs]
def nirspec_ifu(
    output_model,
    f_flat_model,
    s_flat_model,
    d_flat_model,
    dispaxis,
    user_supplied_flat=None,
    inverse=False,
):
    """
    Flat-field NIRSpec IFU data, in-place.

    The flat arrays come from the user-supplied flat or from
    `flat_for_nirspec_ifu`. If the latter reports that nothing was
    updated, the step is marked ``"SKIPPED"`` and None is returned.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model, modified (flat fielded) in-place.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    user_supplied_flat : `~stdatamodels.jwst.datamodels.ImageModel` or None, optional
        A pre-computed flat to use directly. If supplied, the reference
        flat models are ignored.
    inverse : bool, optional
        Invert the math operations used to apply the flat field.

    Returns
    -------
    result : `~stdatamodels.jwst.datamodels.ImageModel` or None
        A new image model holding the flat that was applied, or None if
        the step was skipped.
    """
    if user_supplied_flat is None:
        flat, flat_dq, flat_err, any_updated = flat_for_nirspec_ifu(
            output_model, f_flat_model, s_flat_model, d_flat_model, dispaxis
        )
    else:
        log.info(f"Pre-computed flat {user_supplied_flat} provided. Using the flat directly")
        flat = user_supplied_flat.data
        flat_dq = user_supplied_flat.dq
        flat_err = user_supplied_flat.err
        any_updated = True

    if not any_updated:
        output_model.meta.cal_step.flat_field = "SKIPPED"
        return None

    # Variances and ERR follow the BASELINE algorithm.
    flat_data_squared = flat * flat
    if inverse:
        output_model.data *= flat
        output_model.var_poisson *= flat_data_squared
        output_model.var_rnoise *= flat_data_squared
        # Var_flat does not exist before the flat-field step: zero it.
        output_model.var_flat = np.zeros_like(output_model.data)
        output_model.err = np.sqrt(output_model.var_poisson + output_model.var_rnoise)
    else:
        output_model.data /= flat
        output_model.var_poisson /= flat_data_squared
        output_model.var_rnoise /= flat_data_squared
        # NIRSpec flats have very small values, so squaring may overflow;
        # silence just that warning.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="ignore",
                message="overflow encountered in square",
                category=RuntimeWarning,
            )
            output_model.var_flat = (output_model.data / flat * flat_err) ** 2
        output_model.err = np.sqrt(
            output_model.var_poisson + output_model.var_rnoise + output_model.var_flat
        )

    output_model.dq |= flat_dq

    # Reconcile NaNs and DQ flags across the output arrays.
    pipe_utils.match_nans_and_flags(output_model)
    output_model.meta.cal_step.flat_field = "COMPLETE"

    # Package the applied flat, carrying over the primary metadata.
    interpolated_flats = datamodels.ImageModel(data=flat, dq=flat_dq, err=flat_err)
    interpolated_flats.update(output_model, only="PRIMARY")
    return interpolated_flats
[docs]
def create_flat_field(
    wl,
    f_flat_model,
    s_flat_model,
    d_flat_model,
    xstart,
    xstop,
    ystart,
    ystop,
    exposure_type,
    dispaxis,
    slit_name,
    slit_nt=None,
):
    """
    Build the combined NIRSpec flat from its three components.

    The fore-optics, spectrograph and detector flats are multiplied
    together, their DQ arrays are merged with `combine_dq`, and their
    relative errors are added in quadrature.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array. This array has
        shape ``(ystop - ystart, xstop - xstart)``.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    xstart, xstop, ystart, ystop : int
        Starting and end pixel numbers (zero indexed) for the slice containing
        the data for the current slit, in Python slice notation: the
        region to be extracted is ``[ystart:ystop, xstart:xstop]``.
    exposure_type : str
        Exposure type for the input.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    slit_name : str or None
        The name of the slit currently being processed.
    slit_nt : namedtuple or None, optional
        For MSA data only, info about the current slit.

    Returns
    -------
    flat_2d : ndarray of float
        The flat field, interpolated over wavelength, same shape as ``wl``.
        Divide the 2-D extracted spectrum by this array to correct for
        flat-field variations. Pixels flagged DO_NOT_USE are set to 1.0.
    flat_dq : ndarray of uint32
        The data quality array corresponding to ``flat_2d``.
    flat_err : ndarray of float
        The error array corresponding to ``flat_2d``.
    """
    region = (xstart, xstop, ystart, ystop)

    f_flat, f_flat_dq, f_flat_err = fore_optics_flat(
        wl, f_flat_model, exposure_type, dispaxis, slit_name, slit_nt
    )
    s_flat, s_flat_dq, s_flat_err = spectrograph_flat(
        wl, s_flat_model, *region, exposure_type, dispaxis, slit_name
    )
    d_flat, d_flat_dq, d_flat_err = detector_flat(
        wl, d_flat_model, *region, exposure_type, dispaxis, slit_name
    )

    flat_2d = f_flat * s_flat * d_flat
    flat_dq = combine_dq(f_flat_dq, s_flat_dq, d_flat_dq, default_shape=flat_2d.shape)

    # Sum the squared relative errors of the components that have an error
    # array. Dividing by the flat before squaring avoids overflow; zeros in
    # a component are replaced with 1.0 (in place) before the division.
    relative_var = np.zeros_like(flat_2d)
    components = ((f_flat, f_flat_err), (s_flat, s_flat_err), (d_flat, d_flat_err))
    for component, component_err in components:
        if component_err is None:
            continue
        np.place(component, component == 0, 1.0)
        relative_var += (component_err / component) ** 2
    flat_err = flat_2d * np.sqrt(relative_var)

    # DO_NOT_USE pixels get a neutral flat value (after the error is computed).
    unusable = np.bitwise_and(flat_dq, dqflags.pixel["DO_NOT_USE"]) != 0
    flat_2d[unusable] = 1.0

    return flat_2d, flat_dq, flat_err
[docs]
def fore_optics_flat(wl, f_flat_model, exposure_type, dispaxis, slit_name, slit_nt):
    """
    Compute the fore-optics component of the NIRSpec flat.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    exposure_type : str
        The exposure type refers to fixed_slit, IFU, or MSA.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    slit_name : str or None
        The name of the slit currently being processed.
    slit_nt : namedtuple or None
        For MOS data (only), this is used to get the quadrant number and
        the indices of the current shutter in the Y and X directions.

    Returns
    -------
    f_flat : ndarray of float32
        The computed flat field for the fore optics. All ones when
        ``f_flat_model`` is None; otherwise pixels flagged DO_NOT_USE are NaN.
    f_flat_dq : ndarray of uint32 or None
        The associated data quality array (None when ``f_flat_model`` is None).
    f_flat_err : ndarray of float32 or None
        The associated error array.
    """
    if f_flat_model is None:
        # No reference: unit flat, no DQ, no error.
        return np.ones(wl.shape, dtype=np.float32), None, None

    # Quadrant numbers on the slit are one-indexed; convert to zero-indexed.
    quadrant = None if slit_nt is None else slit_nt.quadrant - 1

    tab_wl, tab_flat, tab_flat_err = read_flat_table(
        f_flat_model, exposure_type, slit_name, quadrant
    )
    if tab_wl.max() < MICRONS_100:
        log.warning("Wavelengths in f_flat table appear to be in meters")

    # The fore-optics flat has no 2-D image component. Even for the MSA,
    # the slowly varying flat is a 1-D array that is folded into tab_flat
    # below, so the "image" part is just 1.
    flat_2d = 1.0
    f_flat_dq = None
    f_flat_err = None

    if exposure_type == "NRS_MSASPEC":
        # The MOS "image" is in MSA coordinates (shutter index in x and y),
        # not detector pixels. As an illustration of xcen and ycen:
        #     shutter_id = xcen + (ycen - 1) * 365
        # Both are one-indexed, hence the -1.
        msa_y = slit_nt.ycen - 1
        msa_x = slit_nt.xcen - 1

        quadrant_flat = f_flat_model.quadrants[quadrant]
        full_array_flat = quadrant_flat.data
        full_array_err = quadrant_flat.err

        # Wavelength of each plane in the "image".
        image_wl = read_image_wl(f_flat_model, quadrant)
        if image_wl.max() < MICRONS_100:
            log.warning("Wavelengths in f_flat image appear to be in meters.")

        one_d_flat = full_array_flat[:, msa_y, msa_x]
        one_d_err = full_array_err[:, msa_y, msa_x]

        # Fold the 1-D MSA flat into the table values: interpolate it at
        # each table wavelength (out-of-range wavelengths get 1.0) and
        # multiply into tab_flat. The 1-D error is interpolated onto the
        # slit wavelengths, with 0.0 outside the covered range.
        tab_flat *= np.interp(tab_wl, image_wl, one_d_flat, 1.0, 1.0)
        f_flat_err = np.interp(wl, image_wl, one_d_err, 0.0, 0.0)

    # Combine 2D and 1D components with error propagation; the output
    # shape is taken from wl.
    f_flat, f_flat_dq, f_flat_err = combine_fast_slow(
        wl, flat_2d, f_flat_dq, f_flat_err, tab_wl, tab_flat, tab_flat_err, dispaxis
    )

    # NaN and zero-valued flat pixels get DO_NOT_USE + NO_FLAT_FIELD.
    do_not_use = dqflags.pixel["DO_NOT_USE"]
    bad_flag = do_not_use + dqflags.pixel["NO_FLAT_FIELD"]
    nan_pixels = np.isnan(f_flat)
    f_flat_dq[nan_pixels] = f_flat_dq[nan_pixels] | bad_flag
    zero_pixels = f_flat == 0.0
    f_flat_dq[zero_pixels] = f_flat_dq[zero_pixels] | bad_flag

    # F-flats provide flux calibration scaling, so there is no safe default
    # value for bad pixels: set every DO_NOT_USE pixel to NaN.
    unusable = (f_flat_dq & do_not_use) != 0
    f_flat[unusable] = np.nan

    return f_flat, f_flat_dq, f_flat_err
[docs]
def spectrograph_flat(
    wl, s_flat_model, xstart, xstop, ystart, ystop, exposure_type, dispaxis, slit_name
):
    """
    Extract the flat for the spectrograph part.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    xstart, xstop, ystart, ystop : int
        Starting and end pixel numbers (zero indexed) for the slice containing
        the data for the current slit. The start and stop values are Python
        slice notation, i.e., the region to be extracted is ``[ystart:ystop, xstart:xstop]``.
    exposure_type : str
        The exposure type refers to fixed slit, IFU, or using the
        micro-shutter array.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    slit_name : str or None
        The name of the slit currently being processed.

    Returns
    -------
    s_flat : ndarray of float32
        The computed flat field for the spectrograph. All ones, with the
        shape of ``wl``, when there is no reference model or the requested
        region is empty.
    s_flat_dq : ndarray of uint32 or None
        The associated data quality array (None for the unit flat).
    s_flat_err : ndarray of float32 or None
        The associated error array (None for the unit flat).
    """
    if s_flat_model is None:
        # No reference: unit flat, no DQ, no error.
        return np.ones(wl.shape, dtype=np.float32), None, None

    if xstart >= xstop or ystart >= ystop:
        # Empty region: nothing to extract. Return the same three-element
        # unit-flat result as the no-model case, so that callers unpacking
        # (flat, dq, err) do not fail on a shorter tuple.
        return np.ones(wl.shape, dtype=np.float32), None, None

    # The spectrograph flat is not split into quadrants.
    quadrant = None

    tab_wl, tab_flat, tab_flat_err = read_flat_table(
        s_flat_model, exposure_type, slit_name, quadrant
    )
    if tab_wl.max() < MICRONS_100:
        log.warning("Wavelengths in s_flat table appear to be in meters")

    full_array_flat = s_flat_model.data
    full_array_dq = s_flat_model.dq
    full_array_err = s_flat_model.err

    if len(full_array_flat.shape) == 3:
        # MSA data: a stack of image planes, one per wavelength, that has to
        # be interpolated onto the slit wavelengths.
        image_flat = full_array_flat[:, ystart:ystop, xstart:xstop]
        image_dq = full_array_dq[:, ystart:ystop, xstart:xstop]
        image_err = full_array_err[:, ystart:ystop, xstart:xstop]

        # Get the wavelength corresponding to each plane in the image.
        image_wl = read_image_wl(s_flat_model, quadrant)
        if image_wl.max() < MICRONS_100:
            log.warning("Wavelengths in s_flat image appear to be in meters")

        flat_2d, s_flat_dq, s_flat_err = interpolate_flat(
            image_flat, image_dq, image_err, image_wl, wl
        )
    else:
        # Single 2-D image: just cut out the slit region.
        flat_2d = full_array_flat[ystart:ystop, xstart:xstop]
        s_flat_dq = full_array_dq[ystart:ystop, xstart:xstop]
        s_flat_err = full_array_err[ystart:ystop, xstart:xstop]

    # Find pixels in the flat that have a value of NaN and add to
    # DQ mask, DO_NOT_USE + NO_FLAT_FIELD
    bad_flag = dqflags.pixel["DO_NOT_USE"] + dqflags.pixel["NO_FLAT_FIELD"]
    flat_nan = np.isnan(flat_2d)
    s_flat_dq[flat_nan] = np.bitwise_or(s_flat_dq[flat_nan], bad_flag)

    # Find pixels in the flat that have a value of zero, and add to
    # DQ mask, DO_NOT_USE + NO_FLAT_FIELD
    flat_zero = np.where(flat_2d == 0.0)
    s_flat_dq[flat_zero] = np.bitwise_or(s_flat_dq[flat_zero], bad_flag)

    # Find all pixels in the flat that have a DQ value of DO_NOT_USE
    flat_bad = np.bitwise_and(s_flat_dq, dqflags.pixel["DO_NOT_USE"])

    # Reset the flat value of all bad pixels to 1.0, so that no
    # correction is made
    flat_2d[np.where(flat_bad)] = 1.0

    # Combine 2D and 1D components with error propagation
    s_flat, s_flat_dq, s_flat_err = combine_fast_slow(
        wl, flat_2d, s_flat_dq, s_flat_err, tab_wl, tab_flat, tab_flat_err, dispaxis
    )
    return s_flat, s_flat_dq, s_flat_err
[docs]
def detector_flat(
    wl, d_flat_model, xstart, xstop, ystart, ystop, exposure_type, dispaxis, slit_name
):
    """
    Extract the flat for the detector part.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    xstart, xstop, ystart, ystop : int
        Starting and end pixel numbers (zero indexed) for the slice containing
        the data for the current slit. The start and stop values are Python
        slice notation, i.e., the region to be extracted is ``[ystart:ystop, xstart:xstop]``.
    exposure_type : str
        The exposure type refers to fixed slit, IFU, or using the
        micro-shutter array.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    slit_name : str or None
        The name of the slit currently being processed.

    Returns
    -------
    d_flat : ndarray of float32, or float
        The computed flat field for the detector. If ``d_flat_model`` is None
        this is an array of ones with the shape of ``wl``; if the requested
        region is empty it is the scalar 1.0.
    d_flat_dq : ndarray of uint32, or None
        The associated data quality array.
    d_flat_err : ndarray of float32, or None
        The associated error array.
    """
    if d_flat_model is None:
        return np.ones(wl.shape, dtype=np.float32), None, None

    # Empty region: nothing to extract. Always return three values so that
    # every return path of this function has the same arity.
    if xstart >= xstop or ystart >= ystop:
        return 1.0, None, None

    # No MSA quadrant is used when reading the detector flat.
    quadrant = None

    tab_wl, tab_flat, tab_flat_err = read_flat_table(
        d_flat_model, exposure_type, slit_name, quadrant
    )
    if tab_wl.max() < MICRONS_100:
        log.warning("Wavelengths in d_flat table appear to be in meters.")

    full_array_flat = d_flat_model.data
    full_array_dq = d_flat_model.dq
    full_array_err = d_flat_model.err

    # The flat image is indexed with a leading (wavelength) axis; DQ and ERR
    # are sliced with an ellipsis so they may be either 2-D or 3-D.
    image_flat = full_array_flat[:, ystart:ystop, xstart:xstop]
    image_dq = full_array_dq[..., ystart:ystop, xstart:xstop]
    image_err = full_array_err[..., ystart:ystop, xstart:xstop]

    # Get the wavelength corresponding to each plane in the image.
    image_wl = read_image_wl(d_flat_model, quadrant)
    if image_wl.max() < MICRONS_100:
        log.warning("Wavelengths in d_flat image appear to be in meters.")

    flat_2d, d_flat_dq, d_flat_err = interpolate_flat(
        image_flat, image_dq, image_err, image_wl, wl
    )

    # Find pixels in the flat that have a value of NaN or zero, and add
    # DO_NOT_USE + NO_FLAT_FIELD to their DQ.
    bad_flag = dqflags.pixel["DO_NOT_USE"] + dqflags.pixel["NO_FLAT_FIELD"]
    flat_nan = np.isnan(flat_2d)
    d_flat_dq[flat_nan] = np.bitwise_or(d_flat_dq[flat_nan], bad_flag)
    flat_zero = np.where(flat_2d == 0.0)
    d_flat_dq[flat_zero] = np.bitwise_or(d_flat_dq[flat_zero], bad_flag)

    # Reset the flat value of all DO_NOT_USE pixels to 1.0, so that no
    # correction is made there.
    flat_bad = np.bitwise_and(d_flat_dq, dqflags.pixel["DO_NOT_USE"])
    flat_2d[np.where(flat_bad)] = 1.0

    # Combine 2D and 1D components with error propagation
    d_flat, d_flat_dq, d_flat_err = combine_fast_slow(
        wl, flat_2d, d_flat_dq, d_flat_err, tab_wl, tab_flat, tab_flat_err, dispaxis
    )
    return d_flat, d_flat_dq, d_flat_err
[docs]
def combine_dq(f_flat_dq, s_flat_dq, d_flat_dq, default_shape):
    """
    Merge the available component DQ arrays with a bitwise OR.

    Parameters
    ----------
    f_flat_dq : ndarray or None
        The DQ array for the fore optics component.
    s_flat_dq : ndarray or None
        The DQ array for the spectrograph component.
    d_flat_dq : ndarray or None
        The DQ array for the detector component.
    default_shape : tuple
        Shape of the starting (all-zero, uint32) DQ array. If all three
        component arrays are None, the result is an array of this shape
        with every BADFLAT bit set.

    Returns
    -------
    flat_dq : ndarray of uint32
        The DQ array resulting from combining the input DQ arrays via
        bitwise OR, with DO_NOT_USE added wherever NO_FLAT_FIELD is set and
        all BADFLAT bits added wherever DO_NOT_USE is set.
    """
    components = [dq for dq in (f_flat_dq, s_flat_dq, d_flat_dq) if dq is not None]

    flat_dq = np.zeros(default_shape, dtype=np.uint32)
    if components:
        for component in components:
            flat_dq = np.bitwise_or(flat_dq, component)
    else:
        # No component supplied any DQ information: mark everything bad.
        flat_dq = np.bitwise_or(flat_dq, BADFLAT)

    # NO_FLAT_FIELD in any component implies DO_NOT_USE.
    no_flat = np.bitwise_and(flat_dq, dqflags.pixel["NO_FLAT_FIELD"]) != 0
    flat_dq[no_flat] |= dqflags.pixel["DO_NOT_USE"]

    # DO_NOT_USE in any component implies the full BADFLAT set
    # (DO_NOT_USE, NO_FLAT_FIELD and UNRELIABLE_FLAT).
    unusable = np.bitwise_and(flat_dq, dqflags.pixel["DO_NOT_USE"]) != 0
    flat_dq[unusable] |= BADFLAT

    return flat_dq
[docs]
def read_image_wl(flat_model, quadrant=None):
    """
    Fetch the wavelength assigned to each plane of a 3-D flat image.

    This function should only be called if the SCI array in ``flat_model``
    is 3-D.

    Parameters
    ----------
    flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or \
                 `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`
        Flat field for the current component.
    quadrant : {0, 1, 2, 3} or None, optional
        The quadrant of the micro-shutter array. When given, the wavelength
        table is read from ``flat_model.quadrants[quadrant]`` instead of
        from ``flat_model`` itself.

    Returns
    -------
    wavelength : ndarray
        A 1D array of wavelengths with NaN and non-positive entries removed.

    Raises
    ------
    ValueError
        If the wavelength column has more than one dimension and cannot be
        flattened to the length of its last axis.
    """
    holder = flat_model if quadrant is None else flat_model.quadrants[quadrant]
    wavelength = holder.wavelength["wavelength"]

    if len(wavelength.shape) > 1:
        # A single table row holding an array: flatten it to 1-D. This
        # only succeeds when all leading axes have length one.
        n_planes = wavelength.shape[-1]
        try:
            wavelength = wavelength.reshape((n_planes,))
        except ValueError:
            log.error(
                "Image wavelength array has shape %s; don't know how to interpret that.",
                str(wavelength.shape),
            )
            raise ValueError("Expected either a scalar column or just one row.") from None

    # NOTE: dropping entries is only safe if the NaN / non-positive values
    # are trailing padding. Embedded bad values would shift the remaining
    # wavelengths relative to the image planes, which are not filtered here.
    not_nan = wavelength[np.logical_not(np.isnan(wavelength))]
    return not_nan[not_nan > 0.0]
[docs]
def read_flat_table(flat_model, exposure_type, slit_name=None, quadrant=None):
    """
    Read the "fast" variation table of a flat-field reference model.

    Parameters
    ----------
    flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or \
                 `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`
        Model whose ``flat_table`` holds the fast-variation data.
    exposure_type : str
        The exposure type. It is only compared against the fixed-slit
        types, to decide whether a table row must be selected by slit name.
    slit_name : str or None, optional
        The name of the slit. Used (case-insensitively) to select the table
        row for fixed-slit data; a row named "ANY" also matches.
    quadrant : {0, 1, 2, 3} or None, optional
        The quadrant of the micro-shutter array. When given, the table is
        read from ``flat_model.quadrants[quadrant]``.

    Returns
    -------
    tab_wl : ndarray of float
        The 1D wavelengths read from the fast-variation table.
    tab_flat : ndarray of float
        The 1D flat-field values, same length as ``tab_wl``.
    tab_flat_err : ndarray of float
        The 1D flat-field errors, same length as ``tab_wl``.

    Raises
    ------
    ValueError
        If a slit-name lookup is required and neither ``slit_name`` nor
        "ANY" is found in the SLIT_NAME column.
    """
    holder = flat_model if quadrant is None else flat_model.quadrants[quadrant]
    data = holder.flat_table

    def _optional_column(name):
        # Columns that may legitimately be absent from the table.
        try:
            return data[name]
        except KeyError:
            return None

    slit_col = _optional_column("slit_name")
    nelem_col = _optional_column("nelem")
    columns = (data["wavelength"], data["data"], data["error"])
    wl_col = columns[0]
    nrows = len(wl_col)

    # Only fixed-slit data select the row based on the slit name.
    row = None
    if exposure_type in FIXED_SLIT_TYPES and slit_col is not None and slit_name is not None:
        accepted = ("any", slit_name.lower())
        # The .strip() works around trailing blanks that may be present in
        # the text value when the table has only one row.
        row = next(
            (i for i in range(nrows) if slit_col[i].lower().strip() in accepted),
            None,
        )
        if row is None:
            log.error("Slit name %s not found in flat field table", slit_name)
            raise ValueError(f"{slit_name} not found in SLIT_NAME column (nor was 'ANY')")

    # Decide which table row supplies the arrays: the matched row, the
    # single row of an array-valued table, or (None) the whole scalar column.
    if row is not None:
        pick = row
    elif len(wl_col.shape) > 1:
        pick = 0
    else:
        pick = None

    if pick is None:
        tab_wl, tab_flat, tab_flat_err = (col.copy() for col in columns)
        # Scalar columns: the row used for nelem is an arbitrary choice.
        nelem = None if nelem_col is None else nelem_col[0]
    else:
        tab_wl, tab_flat, tab_flat_err = (col[pick].copy() for col in columns)
        nelem = None if nelem_col is None else nelem_col[pick]

    if nelem is None:
        nelem = len(tab_wl)
    elif len(tab_wl) < nelem:
        log.error(
            "The fast_variation array size %d in "
            "the data model is too small, and table data were "
            "truncated.",
            len(tab_wl),
        )
        nelem = len(tab_wl)  # truncated!
    else:
        tab_wl = tab_wl[:nelem]
        tab_flat = tab_flat[:nelem]
        tab_flat_err = tab_flat_err[:nelem]

    # Trailing placeholder rows should already be gone thanks to nelem, but
    # only if an nelem column was present and set correctly: drop NaNs too.
    not_nan = np.logical_and(
        np.logical_not(np.isnan(tab_wl)), np.logical_not(np.isnan(tab_flat))
    )
    n_not_nan = not_nan.sum(dtype=np.intp)
    if n_not_nan != nelem:
        log.debug(
            "The table wavelength or flat-field data array contained "
            "%d NaNs; these have been skipped.",
            nelem - n_not_nan,
        )
        tab_wl, tab_flat, tab_flat_err = (
            column[not_nan] for column in (tab_wl, tab_flat, tab_flat_err)
        )

    # Skip zero or negative wavelengths, and skip zero flat-field values.
    usable = np.logical_and(tab_wl > 0.0, tab_flat != 0.0)
    n_usable = usable.sum(dtype=np.intp)
    if n_usable != n_not_nan:
        log.debug(
            "The table wavelength or flat-field data array contained "
            "%d zero or negative values; these have been skipped.",
            n_not_nan - n_usable,
        )
        tab_wl, tab_flat, tab_flat_err = (
            column[usable] for column in (tab_wl, tab_flat, tab_flat_err)
        )

    # np.interp (see combine_fast_slow) needs increasing wavelengths; a
    # violation is only reported here, not corrected.
    if len(tab_wl) > 1 and np.any(tab_wl[1:] - tab_wl[0:-1] <= 0.0):
        log.warning("Wavelengths in the fast-variation table must be strictly increasing.")

    return tab_wl, tab_flat, tab_flat_err
[docs]
def combine_fast_slow(wl, flat_2d, flat_dq, flat_err, tab_wl, tab_flat, tab_flat_error, dispaxis):
    """
    Multiply the image by the tabular values.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array.
    flat_2d : ndarray or float
        The flat field derived from the image part of the reference file,
        or a scalar (e.g., 1.0) if there is no image part in the current
        reference file.
    flat_dq : ndarray or None
        If not None, the data quality array corresponding to ``flat_2d``.
        A copy of this will be updated with flags which may be set based
        on the fast-variation component, and the updated array will be
        returned. If None, an all-zero uint32 array is used instead.
    flat_err : ndarray or None
        If not None, the error array corresponding to ``flat_2d``.
        A copy of this will be combined in quadrature with the errors from
        the fast-variation component. If None, zeros are used instead.
    tab_wl : ndarray
        1D wavelengths corresponding to ``tab_flat``.
    tab_flat : ndarray
        The 1D flat field from the table part of the reference file. This
        is the "fast" variation of the flat, i.e., fast with respect to
        wavelength.
    tab_flat_error : ndarray
        The 1D flat field error from the table part of the reference file.
    dispaxis : int
        1 is horizontal, 2 is vertical.

    Returns
    -------
    combined_flat : ndarray
        The 2D product of ``flat_2d`` and the values in ``tab_flat``
        averaged over each pixel's wavelength interval.
    combined_dq : ndarray
        The updated data quality array. Pixels whose averaged table value
        is NaN (``np.interp`` returns NaN outside the range of ``tab_wl``)
        are flagged NO_FLAT_FIELD and DO_NOT_USE, and their fast-variation
        value is set to 1. Pixels with ``wl <= 0`` also get a
        fast-variation value of 1 and a table error of 0, but no flag is
        added for them here.
    combined_err : ndarray
        The quadrature sum of the interpolated table error scaled by
        ``flat_2d`` and the input error scaled by the fast-variation
        values, with NaN terms treated as zero.
    """
    wl_c = clean_wl(wl, dispaxis)
    dwl = np.zeros_like(wl_c)

    combined_dq = np.zeros(wl.shape, dtype=np.uint32) if flat_dq is None else flat_dq.copy()
    combined_err = np.zeros(wl.shape, dtype=np.float64) if flat_err is None else flat_err.copy()

    # Wavelength step per pixel along the dispersion direction. The last
    # pixel has no successor, so it reuses the previous step. With fewer
    # than two pixels along that axis there is no neighbour to copy from
    # (indexing -2 would raise IndexError), so the step is left at zero.
    if dispaxis == HORIZONTAL:
        dwl[:, 0:-1] = wl_c[:, 1:] - wl_c[:, 0:-1]
        if dwl.shape[1] > 1:
            dwl[:, -1] = dwl[:, -2]
    elif dispaxis == VERTICAL:
        dwl[0:-1, :] = wl_c[1:, :] - wl_c[0:-1, :]
        if dwl.shape[0] > 1:
            dwl[-1, :] = dwl[-2, :]

    # Values averaged within tab_flat.
    values = np.zeros_like(wl_c)

    # Abscissas and weights for 3-point Gaussian integration, but taking
    # the width of the interval to be 1, so the result will be the average
    # over the interval.
    d = math.sqrt(0.6) / 2.0
    dx = np.array([-d, 0.0, d])
    wgt = np.array([5.0, 8.0, 5.0]) / 18.0

    # Interpolate tabular data over the range of wavelengths,
    # weight, and sum at each of 3 specified points
    for offset, weight in zip(dx, wgt, strict=True):
        wavelengths = wl_c + dwl * offset
        values += weight * np.interp(wavelengths, tab_wl, tab_flat, left=np.nan, right=np.nan)

    # Interpolate error values from reference file using a simple
    # linear interpolation as these don't have the required precision
    # to justify a more complex interpolation
    error_value = np.interp(wl_c, tab_wl, tab_flat_error, left=np.nan, right=np.nan)

    # Handle bad wavelength values in un-cleaned wavelength array
    bad = wl <= 0
    values[bad] = 1.0
    error_value[bad] = 0.0

    # Handle missing values
    missing = np.isnan(values)
    values[missing] = 1.0
    error_value[missing] = 0.0
    combined_dq[missing] |= dqflags.pixel["NO_FLAT_FIELD"]
    combined_dq[missing] |= dqflags.pixel["DO_NOT_USE"]

    # Add new 1D errors and input 2D errors in quadrature,
    # treating NaNs as zeros
    v1 = np.square(error_value * flat_2d)
    v2 = np.square(combined_err * values)
    combined_err = np.sqrt(np.nansum([v1, v2], axis=0))

    return flat_2d * values, combined_dq, combined_err
[docs]
def clean_wl(wl, dispaxis):
    """
    Fill in non-positive and NaN entries of a 2-D wavelength array.

    Parameters
    ----------
    wl : ndarray
        Wavelength at each pixel of the 2-D slit array.
    dispaxis : int
        1 is horizontal, 2 is vertical.

    Returns
    -------
    wl_c : ndarray
        A copy of ``wl``. For a horizontal or vertical ``dispaxis``, zero,
        negative and NaN values are replaced: within each cross-dispersion
        line that has at least one finite value, by the mean of that line's
        non-NaN values; lines before the first (after the last) such line
        are filled with a copy of that line. For any other ``dispaxis``
        the copy is returned unchanged.
    """
    wl_c = wl.copy()  # so we can replace zeros
    wl_c[wl_c <= 0.0] = np.nan

    # Work on a view whose second axis is the dispersion direction, so one
    # code path serves both orientations; writes go through to wl_c.
    if dispaxis == HORIZONTAL:
        lines = wl_c
    elif dispaxis == VERTICAL:
        lines = wl_c.T
    else:
        return wl.copy()

    n_cross, n_disp = lines.shape
    first = None  # first index along dispersion with some valid wl
    last = None  # last such index
    for idx in range(n_disp):
        line = lines[:, idx]
        if not np.any(np.isfinite(line)):
            continue
        if first is None:
            first = idx
        last = idx
        fill_value = np.nanmean(line)
        line[np.isnan(line)] = fill_value

    # Extend the first / last valid line outward over the empty edges.
    if first is not None and first > 0:
        edge = lines[:, first].reshape(n_cross, 1)
        lines[:, 0:first] = edge.copy()
    if last is not None and last < n_disp - 1:
        edge = lines[:, last].reshape(n_cross, 1)
        lines[:, last:] = edge.copy()

    return wl_c
[docs]
def interpolate_flat(image_flat, image_dq, image_err, image_wl, wl):
    """
    Collapse a 3-D flat-field image to 2-D by interpolating in wavelength.

    Parameters
    ----------
    image_flat : ndarray
        3D array, corresponding to slice ``[:, ystart:ystop, xstart:xstop]`` of
        the flat field reference image; the first axis is wavelength. An
        array with fewer than three dimensions is returned as is, together
        with ``image_dq`` and ``image_err``.
    image_dq : ndarray
        2D or 3D array corresponding to slice ``[..., ystart:ystop, xstart:xstop]``
        of the data quality array for the flat field reference image.
    image_err : ndarray
        2D or 3D array corresponding to slice ``[..., ystart:ystop, xstart:xstop]``
        of the error array for the flat field reference image.
    image_wl : ndarray
        The 1D wavelength for each plane of the flat field reference image.
    wl : ndarray
        The wavelength at each pixel of the 2D extracted science spectrum.

    Returns
    -------
    flat_2d : ndarray
        The flat field, linearly interpolated over wavelength, with the
        dtype of ``image_flat``. Pixels flagged DO_NOT_USE are set to 1.
    flat_dq : ndarray
        The data quality array corresponding to ``flat_2d``. Pixels whose
        wavelength lies outside ``[image_wl[0], image_wl[-1]]`` are flagged
        DO_NOT_USE.
    flat_err : ndarray
        The error array corresponding to ``flat_2d``.
    """
    if len(image_flat.shape) < 3:
        return image_flat, image_dq, image_err

    nz, ysize, xsize = image_flat.shape
    plane_shape = (ysize, xsize)

    if nz == 1:
        # Single plane: nothing to interpolate, just drop the leading axis.
        single_plane = image_flat.reshape(plane_shape)
        if len(image_dq.shape) == 2:
            # dq and err arrays are the same size, so treat them the same
            return single_plane, image_dq, image_err
        return single_plane, image_dq.reshape(plane_shape), image_err.reshape(plane_shape)

    iypixel, ixpixel = np.indices(plane_shape, dtype=np.intp)

    # k is the index of the lower plane of the interpolation interval for
    # each pixel; -1 marks pixels that have not been assigned yet.
    k = np.full(wl.shape, -1, dtype=np.intp)

    # Out-of-range wavelengths get harmless in-bounds indices. The upper
    # limit is nz - 2 because planes k and k + 1 are both used.
    k[wl <= image_wl[0]] = 0
    k[wl >= image_wl[nz - 1]] = nz - 2

    # Look for the correct interval for linear interpolation, never
    # overwriting an index that has already been assigned.
    for lower in range(nz - 1):
        in_interval = np.logical_and(wl >= image_wl[lower], wl < image_wl[lower + 1])
        k[np.logical_and(k == -1, in_interval)] = lower
        if np.all(k >= 0):
            break

    upper = k + 1
    denom = image_wl[upper] - image_wl[k]
    zero_denom = denom == 0.0
    denom = np.where(zero_denom, 1.0, denom)

    # Linear interpolation: flat = (1 - p) * flat[k] + p * flat[k + 1],
    # where p is the fractional position of wl within the interval.
    p = np.where(zero_denom, 0.0, (wl - image_wl[k]) / denom)
    q = 1.0 - p
    flat_2d = q * image_flat[k, iypixel, ixpixel] + p * image_flat[upper, iypixel, ixpixel]

    if len(image_err.shape) == 2:
        flat_err = image_err.copy()
    else:
        flat_err = q * image_err[k, iypixel, ixpixel] + p * image_err[upper, iypixel, ixpixel]

    if len(image_dq.shape) == 2:
        flat_dq = image_dq.copy()
    else:
        dq_lower = image_dq[k, iypixel, ixpixel]
        dq_upper = image_dq[upper, iypixel, ixpixel]
        # Only the lower plane contributes when the weight of the upper
        # plane is exactly zero.
        flat_dq = np.where(p == 0.0, dq_lower, np.bitwise_or(dq_lower, dq_upper))

    # If the wavelength at a pixel is outside the range of wavelengths
    # for the reference image, flag the pixel as bad.
    do_not_use = dqflags.pixel["DO_NOT_USE"]
    out_of_range = np.logical_or(wl < image_wl[0], wl > image_wl[-1])
    flat_dq[out_of_range] = np.bitwise_or(flat_dq[out_of_range], do_not_use)

    # If a pixel is flagged as bad (from the reference DQ or just above),
    # applying flat_2d should not make any change to the science data.
    flat_2d[np.bitwise_and(flat_dq, do_not_use) != 0] = 1.0

    return flat_2d.astype(image_flat.dtype), flat_dq, flat_err
[docs]
def flat_for_nirspec_ifu(output_model, f_flat_model, s_flat_model, d_flat_model, dispaxis):
    """
    Create the interpolated flat for NIRSpec IFU.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model. Its ``data`` and ``dq`` arrays set the shape
        and dtype of the returned arrays, and the per-slice WCS objects are
        obtained from it.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.

    Returns
    -------
    flat, flat_dq, flat_err : ndarray
        The interpolated flat correction and associated DQ and error.
        Pixels not covered by any slice, or flagged DO_NOT_USE, are NaN in
        ``flat`` and ``flat_err``.
    any_updated : bool
        True if any slice of the IFU has been processed.

    Raises
    ------
    RuntimeError
        If the per-slice WCS cannot be obtained from ``output_model``.
    """
    any_updated = False
    exposure_type = output_model.meta.exposure.type
    flat = np.ones_like(output_model.data) * np.nan
    flat_dq = np.zeros_like(output_model.dq)
    flat_err = np.zeros_like(output_model.data) * np.nan

    try:
        list_of_wcs = nirspec.nrs_ifu_wcs(output_model)
    except (KeyError, AttributeError):
        if output_model.meta.cal_step.assign_wcs == "COMPLETE":
            log.error("The input file does not appear to have WCS info.")
            raise RuntimeError("Problem accessing WCS information.") from None
        log.error("This mode %s requires WCS information.", exposure_type)
        raise RuntimeError("The assign_wcs step has not been run.") from None

    for k, ifu_wcs in enumerate(list_of_wcs):
        # example:  bounding_box = ((1600.5, 2048.5),   # X
        #                           (1886.5, 1925.5))   # Y
        truncated = False
        try:
            xstart = ifu_wcs.bounding_box[0][0]
            xstop = ifu_wcs.bounding_box[0][1]
            ystart = ifu_wcs.bounding_box[1][0]
            ystop = ifu_wcs.bounding_box[1][1]
            log.debug("Using ifu_wcs.bounding_box.")
        except AttributeError:
            log.info("ifu_wcs.bounding_box not found; using domain instead.")
            xstart = ifu_wcs.domain[0]["lower"]
            xstop = ifu_wcs.domain[0]["upper"]
            ystart = ifu_wcs.domain[1]["lower"]
            ystop = ifu_wcs.domain[1]["upper"]

        # Clip the box to pixel indices 0..2047 on both axes.
        if xstart < -0.5:
            truncated = True
            log.info("xstart from WCS bounding_box was %g", xstart)
            xstart = 0.0
        if ystart < -0.5:
            truncated = True
            log.info("ystart from WCS bounding_box was %g", ystart)
            ystart = 0.0
        if xstop > 2047.5:
            truncated = True
            log.info("xstop from WCS bounding_box was %g", xstop)
            xstop = 2047.0
        if ystop > 2047.5:
            truncated = True
            log.info("ystop from WCS bounding_box was %g", ystop)
            ystop = 2047.0
        if truncated:
            log.info(
                "WCS bounding_box for stripe %d extended beyond image "
                "edges, has been truncated to ...",
                k,
            )
            log.info(" xstart=%g, xstop=%g, ystart=%g, ystop=%g", xstart, xstop, ystart, ystop)

        # Convert these to integers, and add one to the upper limits,
        # because we want to use these as slice limits.
        xstart = int(math.ceil(xstart))
        xstop = int(math.floor(xstop)) + 1
        ystart = int(math.ceil(ystart))
        ystop = int(math.floor(ystop)) + 1

        dx = xstop - xstart
        dy = ystop - ystart
        ind = np.indices((dy, dx))
        x = ind[1] + xstart
        y = ind[0] + ystart
        coords = ifu_wcs(x, y)
        wl = coords[2]
        nan_flag = np.isnan(wl)
        good_flag = np.logical_not(nan_flag)

        # Guard against a slice with no valid wavelengths at all: calling
        # .max() on an empty selection would raise ValueError.
        if good_flag.any() and wl[good_flag].max() < MICRONS_100:
            log.warning("Wavelengths in WCS table appear to be in meters")

        # Set NaNs to a relatively harmless value, but don't modify nan_flag.
        wl[nan_flag] = 0.0

        flat_2d, flat_dq_2d, flat_err_2d = create_flat_field(
            wl,
            f_flat_model,
            s_flat_model,
            d_flat_model,
            xstart,
            xstop,
            ystart,
            ystop,
            exposure_type,
            dispaxis,
            None,
            None,
        )

        # Non-positive flat values cannot be used as a divisor.
        mask = flat_2d <= 0.0
        nbad = mask.sum(dtype=np.intp)
        if nbad > 0:
            log.debug("%d flat-field values <= 0", nbad)
            flat_2d[mask] = np.nan
            flat_dq_2d[mask] = np.bitwise_or(flat_dq_2d[mask], BADFLAT)

        # Only pixels with a valid wavelength are copied into the full frame.
        flat[ystart:ystop, xstart:xstop][good_flag] = flat_2d[good_flag]
        if flat_dq.dtype != flat_dq_2d.dtype:
            log.warning(
                "flat_dq.dtype = %s  flat_dq_2d.dtype = %s", flat_dq.dtype, flat_dq_2d.dtype
            )
        flat_dq[ystart:ystop, xstart:xstop][good_flag] |= flat_dq_2d[good_flag].astype(
            flat_dq.dtype
        )
        flat_err[ystart:ystop, xstart:xstop][good_flag] = flat_err_2d[good_flag]
        any_updated = True

    # Ensure consistency between NaN-valued pixels and the DO_NOT_USE flag
    indx = np.where((flat_dq & dqflags.pixel["DO_NOT_USE"]) != 0)
    flat[indx] = np.nan
    flat_err[indx] = np.nan

    indx = np.where(~np.isfinite(flat))
    flat_dq[indx] = flat_dq[indx] | dqflags.pixel["DO_NOT_USE"]

    return flat, flat_dq, flat_err, any_updated
[docs]
def flat_for_nirspec_brightobj(output_model, f_flat_model, s_flat_model, d_flat_model, dispaxis):
    """
    Create the interpolated flat for NIRSpec bright object data.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data model. If no wavelengths can be obtained for it, its
        ``meta.cal_step.flat_field`` is set to "SKIPPED".
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.

    Returns
    -------
    flat : `~stdatamodels.jwst.datamodels.ImageModel` or None
        The interpolated flat correction, or None if neither a populated
        wavelength array nor a WCS is available.
    """
    exposure_type = output_model.meta.exposure.type
    wcs_available = getattr(output_model.meta, "wcs", None) is not None

    # Create an output model for the interpolated flat fields.
    interpolated_flats = datamodels.ImageModel()
    interpolated_flats.update(output_model, only="PRIMARY")
    if wcs_available:
        interpolated_flats.meta.wcs = output_model.meta.wcs
    slit_name = output_model.name

    # The input may be either 2-D or 3-D; only the image axes matter here.
    ysize, xsize = output_model.data.shape[-2:]

    # Pixel limits with respect to the original image. The subarray and
    # slit start values each have 1 subtracted to give zero-based offsets.
    xstart = output_model.meta.subarray.xstart - 1 + output_model.xstart - 1
    ystart = output_model.meta.subarray.ystart - 1 + output_model.ystart - 1
    xstop = xstart + xsize
    ystop = ystart + ysize

    # The wavelength of each pixel in a plane of the data.
    try:
        wl = output_model.wavelength.copy()  # a 2-D array
    except AttributeError:
        wl = []
    have_wl_array = len(wl) != 0

    # There must be either a populated wavelength array or a meta.wcs.
    if have_wl_array and not (np.nanmin(wl) == 0.0 and np.nanmax(wl) == 0.0):
        log.debug("Wavelengths are from the wavelength array.")
    else:
        log.warning("The wavelength array has not been populated,")
        if not wcs_available:
            log.warning("and there is no 'wcs' attribute,")
            if output_model.meta.cal_step.assign_wcs == "COMPLETE":
                log.warning("assign_wcs has been run, however.")
            else:
                log.warning("likely because assign_wcs has not been run.")
            log.error("Skipping flat_field.")
            output_model.meta.cal_step.flat_field = "SKIPPED"
            return None
        log.warning("so using wcs instead of the wavelength array.")
        pixel_y, pixel_x = np.indices((ysize, xsize), dtype=np.float64)
        _, _, wl = output_model.meta.wcs(pixel_x, pixel_y)

    is_nan = np.isnan(wl)
    n_nan = is_nan.sum(dtype=np.intp)
    n_good = np.logical_not(is_nan).sum(dtype=np.intp)
    if n_nan > 0:
        log.debug(
            "Number of NaNs in wavelength array = %d out of %d",
            n_nan,
            n_nan + n_good,
        )
        if n_good < 1:
            log.warning("(all are NaN)")
        # Replace NaNs with a relatively harmless but out-of-bounds value.
        wl[is_nan] = 0.0

    # Combine the three flat fields. The same flat will be applied to
    # each plane (integration) in the cube.
    flat_2d, flat_dq_2d, flat_err_2d = create_flat_field(
        wl,
        f_flat_model,
        s_flat_model,
        d_flat_model,
        xstart,
        xstop,
        ystart,
        ystop,
        exposure_type,
        dispaxis,
        slit_name,
        None,
    )

    # Non-positive flat values are replaced by NaN and flagged as bad.
    non_positive = flat_2d <= 0.0
    n_non_positive = non_positive.sum(dtype=np.intp)
    if n_non_positive > 0:
        log.debug("%d flat-field values <= 0", n_non_positive)
        flat_2d[non_positive] = np.nan
        flat_dq_2d[non_positive] = np.bitwise_or(flat_dq_2d[non_positive], BADFLAT)

    interpolated_flats.data = flat_2d
    interpolated_flats.dq = flat_dq_2d.astype(output_model.get_dtype("dq"))
    interpolated_flats.err = flat_err_2d.astype(output_model.get_dtype("err"))
    interpolated_flats.wavelength = wl

    return interpolated_flats
[docs]
def flat_for_nirspec_slit(
    slit,
    f_flat_model,
    s_flat_model,
    d_flat_model,
    dispaxis,
    exposure_type,
    slit_nt,
    subarray,
    use_wavecorr,
):
    """
    Create the interpolated flat for NIRSpec slit data.

    The wavelength of every pixel in the slit is looked up, the fore optics,
    spectrograph and detector flats are combined at those wavelengths, and
    the result is returned as a new slit model.  If no wavelengths can be
    obtained for the slit, a placeholder flat (data of ones, DQ and error of
    zeros, wavelength of zeros, no WCS) is returned instead.

    Parameters
    ----------
    slit : `~stdatamodels.jwst.datamodels.SlitModel`
        A slit to process.
    f_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel`, \
                   `~stdatamodels.jwst.datamodels.NirspecQuadFlatModel`, \
                   or None
        Flat field for the fore optics.
    s_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the spectrograph.
    d_flat_model : `~stdatamodels.jwst.datamodels.NirspecFlatModel` or None
        Flat field for the detector.
    dispaxis : int
        1 means horizontal dispersion, 2 means vertical dispersion.
    exposure_type : str
        The exposure type.
    slit_nt : namedtuple or None
        For MOS data (only), this is used to get the quadrant number and
        the indices of the current shutter in the Y and X directions.
    subarray : object
        The subarray specification from ``JwstDataModel.meta.subarray``.
        Only its ``xstart`` and ``ystart`` attributes are read here.
    use_wavecorr : bool or None
        Flag indicating whether or not to use the corrected wavelengths
        provided (upstream) by the wavecorr step.

    Returns
    -------
    flat : `~stdatamodels.jwst.datamodels.SlitModel`
        The calculated flat, carrying the same name and position
        attributes as ``slit``.
    """
    # Default flat, DQ and error arrays; these are only returned when no
    # wavelengths are available for the slit.
    flat_2d = np.ones_like(slit.data)
    flat_dq_2d = np.zeros_like(slit.dq)
    flat_err_2d = np.zeros_like(slit.err)

    # Pixel limits of the slit with respect to the original image.
    # NOTE(review): the "- 1" terms presumably convert 1-indexed header
    # values to 0-indexed array offsets -- confirm against the data model.
    ysize, xsize = slit.data.shape[-2:]
    xstart = slit.xstart - 1 + subarray.xstart - 1
    ystart = slit.ystart - 1 + subarray.ystart - 1
    xstop = xstart + xsize
    ystop = ystart + ysize

    wcs = getattr(slit.meta, "wcs", None)

    # Get the wavelength at each pixel in the extracted slit data.
    # If the wavelength attribute exists and is populated, use it
    # in preference to the wavelengths returned by the wcs function.
    wl = wcs_utils.get_wavelengths(slit, use_wavecorr=use_wavecorr)
    if wl is None:
        # No wavelengths to interpolate at: return the placeholder flat.
        return _slit_flat_model(
            slit, flat_2d, flat_dq_2d, flat_err_2d, np.zeros_like(slit.data)
        )

    # We've got everything we need for the rest of processing.
    nan_mask = np.isnan(wl)
    n_nan = nan_mask.sum(dtype=np.intp)
    if n_nan > 0:
        n_total = nan_mask.size
        log.debug("Number of NaNs in sci wavelength array = %d out of %d", n_nan, n_total)
        if n_total - n_nan < 1:
            log.warning("(all are NaN)")
        # Replace NaNs with a relatively harmless but out-of-bounds value.
        wl[nan_mask] = 0.0

    max_wavelength = np.nanmax(wl)
    if 0.0 < max_wavelength < MICRONS_100:
        log.warning("Wavelengths in science data appear to be in meters.")

    # Combine the three flat fields for the current subarray.
    flat_2d, flat_dq_2d, flat_err_2d = create_flat_field(
        wl,
        f_flat_model,
        s_flat_model,
        d_flat_model,
        xstart,
        xstop,
        ystart,
        ystop,
        exposure_type,
        dispaxis,
        slit.name,
        slit_nt,
    )

    # Mask bad flat-field values: non-positive flats become NaN and are
    # flagged in the DQ array.
    mask = flat_2d <= 0.0
    nbad = mask.sum(dtype=np.intp)
    if nbad > 0:
        log.debug("%d flat-field values <= 0", nbad)
        flat_2d[mask] = np.nan
        flat_dq_2d[mask] = np.bitwise_or(flat_dq_2d[mask], BADFLAT)
    del mask

    # Put the computed flat, flat_dq and flat_err into a datamodel.
    new_flat = _slit_flat_model(slit, flat_2d, flat_dq_2d, flat_err_2d, wl.copy())

    # Copy the WCS info from output (same as input).
    if wcs is not None:
        new_flat.meta.wcs = wcs

    return new_flat


def _slit_flat_model(slit, data, dq, err, wavelength):
    """
    Wrap flat-field arrays in a slit model that mirrors ``slit``.

    Parameters
    ----------
    slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Slit whose ``name``, ``xstart``, ``xsize``, ``ystart`` and ``ysize``
        are copied to the new model.
    data, dq, err : ndarray
        Flat-field values, data quality flags and errors.
    wavelength : ndarray
        Wavelength array to attach to the new model.

    Returns
    -------
    flat : `~stdatamodels.jwst.datamodels.SlitModel`
        New model holding the supplied arrays.
    """
    flat = datamodels.SlitModel(data=data, dq=dq, err=err)
    flat.name = slit.name
    flat.xstart = slit.xstart
    flat.xsize = slit.xsize
    flat.ystart = slit.ystart
    flat.ysize = slit.ysize
    flat.wavelength = wavelength
    return flat