import warnings
from functools import cached_property
from math import isclose
from typing import Sequence, Literal
import astropy.units as u
import numpy as np
from astropy.modeling import models, Model, fitting, CompoundModel
from matplotlib.pyplot import Axes, Figure, setp, subplots
from numpy.ma import MaskedArray
from numpy.typing import ArrayLike
from scipy import optimize, ndimage
from scipy.spatial import KDTree
from specreduce.calibration_data import load_pypeit_calibration_lines
from specreduce.compat import Spectrum
from specreduce.line_matching import find_arc_lines
__all__ = ["WavelengthCalibration1D"]
from specreduce.wavesol1d import WavelengthSolution1D
def _format_linelist(lst: ArrayLike) -> MaskedArray:
    """Force a line list into a MaskedArray with a shape of (n, 2) where n is the number of lines.

    Parameters
    ----------
    lst
        Input array of centroids or centroids with amplitudes. Must be either:

        - A 1D array with a shape [n] for centroids.
        - A 2D array with a shape [n, 2] for centroids and amplitudes.

    Returns
    -------
    numpy.ma.MaskedArray
        Copy of the input as an array with shape [n, 2], sorted by centroid. Each row
        contains a line centroid and amplitude; the amplitude is 0.0 when the input was 1D.
        A row is masked as a whole if any of its elements is masked.

    Raises
    ------
    ValueError
        If the input is neither 1D with a shape [n] nor 2D with a shape [n, 2].
    """
    lst = MaskedArray(lst, copy=True)
    # Expand a possible ``nomask`` into a full boolean array so it can be edited per element.
    lst.mask = np.ma.getmaskarray(lst)

    # Reject everything but [n] and [n, 2]; scalars and [n, 1] arrays would otherwise
    # slip through and fail later with an unrelated error.
    if lst.ndim not in (1, 2) or (lst.ndim == 2 and lst.shape[1] != 2):
        raise ValueError(
            "Line lists must be 1D with a shape [n] (centroids) or "
            "2D with a shape [n, 2] (centroids and amplitudes)."
        )

    if lst.ndim == 1:
        # Duplicate the centroid column and overwrite the copy with zero amplitudes.
        lst = MaskedArray(np.tile(lst[:, None], [1, 2]))
        lst[:, 1] = 0.0

    # A line is either fully masked or fully unmasked.
    lst.mask[:, :] = lst.mask.any(axis=1)[:, None]
    return lst[np.argsort(lst.data[:, 0])]
def _unclutter_text_boxes(labels: Sequence) -> None:
    """Remove overlapping labels from the plot, keeping the ones with the highest z-order.

    The labels are visited in order of decreasing z-order. A label is kept if its bounding
    box does not overlap the bounding box of any label kept so far, and is removed from the
    plot otherwise. A label can therefore only be displaced by a label that stays visible.
    Among labels with equal z-order, the one that comes first in ``labels`` wins.

    Parameters
    ----------
    labels
        A list of matplotlib.text.Text objects.
    """
    # ``sorted`` is stable also with ``reverse=True``, so ties keep their input order.
    by_priority = sorted(labels, key=lambda lb: lb.zorder, reverse=True)

    kept_boxes = []
    for label in by_priority:
        # Query the extent once per label instead of once per label pair.
        bbox = label.get_window_extent()
        if any(kept.overlaps(bbox) for kept in kept_boxes):
            label.remove()
        else:
            kept_boxes.append(bbox)
[docs]
class WavelengthCalibration1D:
def __init__(
    self,
    arc_spectra: Spectrum | Sequence[Spectrum] | None = None,
    obs_lines: ArrayLike | Sequence[ArrayLike] | None = None,
    line_lists: ArrayLike | None = None,
    unit: u.Unit = u.angstrom,
    ref_pixel: float | None = None,
    pix_bounds: tuple[int, int] | None = None,
    line_list_bounds: None | tuple[float, float] = None,
    n_strogest_lines: None | int = None,
    wave_air: bool = False,
) -> None:
    """Set up a wavelength calibration for one-dimensional spectra.

    The calibration works either on arc spectra, from which lines are detected later with
    `find_lines`, or directly on observed line positions. Catalog line lists can be given
    here as well. The wavelength solution itself is a polynomial centred on ``ref_pixel``.

    Parameters
    ----------
    arc_spectra
        Arc spectra provided as ``Spectrum`` objects for wavelength fitting, by default
        None. Cannot be combined with ``obs_lines``.
    obs_lines
        Pixel positions of observed spectral lines for wavelength fitting, by default None.
        Cannot be combined with ``arc_spectra``.
    line_lists
        Catalogs of spectral line wavelengths for wavelength calibration. Provide either an
        array of line wavelengths or a list of `PypeIt <https://github.com/pypeit/PypeIt>`_
        catalog names. If `None`, no line lists are used. You can query the list of available
        catalog names via `~specreduce.calibration_data.get_available_line_catalogs`.
    unit
        The unit of the wavelength calibration, by default ``astropy.units.Angstrom``.
    ref_pixel
        The reference pixel in which the wavelength solution will be centered. Defaults to
        the middle of the first arc spectrum when ``arc_spectra`` is given; mandatory when
        ``obs_lines`` is given.
    pix_bounds
        Lower and upper pixel bounds for fitting, defined as a tuple (min, max). Mandatory
        when ``obs_lines`` is given; replaced by ``(0, n_pixels)`` when ``arc_spectra`` is
        given.
    line_list_bounds
        Wavelength bounds as a tuple (min, max) for filtering usable spectral
        lines from the provided line lists.
    n_strogest_lines
        The number of strongest lines to be included from the line lists. If `None`, all
        are included.
    wave_air
        Boolean indicating whether the input wavelengths correspond to air rather than vacuum;
        by default `False`, meaning vacuum wavelengths.

    Raises
    ------
    ValueError
        If ``ref_pixel`` is negative, if both ``arc_spectra`` and ``obs_lines`` are given,
        if the arc spectra are not one dimensional or differ in length, if ``obs_lines`` is
        given without ``pix_bounds`` or ``ref_pixel``, or if the number of line lists does
        not match the number of frames.
    """
    self.unit = unit
    self._unit_str = unit.to_string("latex")
    self.degree = None
    self.ref_pixel = ref_pixel
    self.nframes = 0

    if ref_pixel is not None and ref_pixel < 0:
        raise ValueError("Reference pixel must be positive.")

    self.arc_spectra: list[Spectrum] | None = None
    self.bounds_pix: tuple[int, int] | None = pix_bounds
    self.bounds_wav: tuple[float, float] | None = None
    self._cat_lines: list[MaskedArray] | None = None
    self._obs_lines: list[MaskedArray] | None = None
    self._trees: list[KDTree] | None = None
    self._fit: optimize.OptimizeResult | None = None
    self.solution = WavelengthSolution1D(None, pix_bounds, unit)

    # The observations come either as arc spectra or as ready-made line positions,
    # never as both.
    have_arcs = arc_spectra is not None
    have_lines = obs_lines is not None
    if have_arcs and have_lines:
        raise ValueError("Only one of arc_spectra or obs_lines can be provided.")

    if have_arcs:
        spectra = [arc_spectra] if isinstance(arc_spectra, Spectrum) else arc_spectra
        self.arc_spectra = spectra
        self.nframes = len(spectra)
        if any(s.data.ndim > 1 for s in spectra):
            raise ValueError("The arc spectrum must be one dimensional.")
        if len({s.data.size for s in spectra}) != 1:
            raise ValueError("All arc spectra must have the same length.")
        # The pixel bounds follow from the spectrum length.
        self.bounds_pix = (0, spectra[0].shape[0])
        self.solution.bounds_pix = self.bounds_pix
        if self.ref_pixel is None:
            self.ref_pixel = spectra[0].data.size / 2
    elif have_lines:
        self.observed_lines = obs_lines
        self.nframes = len(self._obs_lines)
        if self.bounds_pix is None:
            raise ValueError("Must give pixel bounds when providing observed line positions.")
        if self.ref_pixel is None:
            raise ValueError("Must give reference pixel when providing observed lines.")

    # Line lists are optional. A single array or catalog name is treated as one line list.
    if line_lists is None:
        return
    if not isinstance(line_lists, (tuple, list)):
        line_lists = [line_lists]
    if len(line_lists) != self.nframes:
        raise ValueError("The number of line lists must match the number of arc spectra.")
    self._read_linelists(
        line_lists,
        line_list_bounds=line_list_bounds,
        wave_air=wave_air,
        n_strongest=n_strogest_lines,
    )
def _line_match_distance(self, x: ArrayLike, model: Model, max_distance: float = 100) -> float:
    """Compute the summed distance from the transformed observed lines to the catalog lines.

    For every frame, the observed line positions are passed through ``model.evaluate``
    together with the shift ``-self.ref_pixel`` and the trial parameters ``x``. Each
    resulting wavelength is matched to its nearest catalog line through the frame's KDTree,
    the distances are clipped to ``[0, max_distance]``, and everything is summed into a
    single scalar that can be used as an objective for global optimization.

    Parameters
    ----------
    x
        Pixel-to-wavelength model parameters (e.g., Polynomial1D coefficients c0..cN).
    model
        The pixel-to-wavelength model to be evaluated.
    max_distance
        Upper bound used to clip individual distances before summation.

    Returns
    -------
    float
        Sum of the clipped nearest-neighbor distances over all frames.
    """

    def frame_distance(tree, pixels):
        wavelengths = model.evaluate(pixels, -self.ref_pixel, *x)
        nearest, _ = tree.query(wavelengths[:, None])
        return np.clip(nearest, 0, max_distance).sum()

    per_frame = (
        frame_distance(tree, pixels)
        for tree, pixels in zip(self._trees, self.observed_line_locations)
    )
    return sum(per_frame, 0.0)
def _read_linelists(
self,
line_lists: Sequence,
line_list_bounds: None | tuple[float, float] = None,
wave_air: bool = False,
n_strongest: None | int = None,
) -> None:
"""Read and processes line lists.
Parameters
----------
line_lists
A collection of line lists that can either be arrays of wavelengths or `PypeIt
<https://github.com/pypeit/PypeIt>`_
lamp names. You can query the list of available catalog names via
`~specreduce.calibration_data.get_available_line_catalogs`.
line_list_bounds
A tuple specifying the minimum and maximum wavelength bounds. Only wavelengths
within this range are retained.
wave_air
If True, convert the vacuum wavelengths used by `PypeIt
<https://github.com/pypeit/PypeIt>`_ to air wavelengths.
n_strongest
The number of strongest lines to be used. If `None`, all lines are used.
"""
lines = []
for lst in line_lists:
if isinstance(lst, np.ndarray):
lines.append(lst)
else:
if isinstance(lst, str):
lst = [lst]
lines.append([])
for ll in lst:
line_table = load_pypeit_calibration_lines(ll, wave_air=wave_air)
if n_strongest is not None:
ix = np.argsort(line_table["amplitude"].value)[::-1]
lines[-1].append(line_table[ix][:n_strongest]["wavelength"].to(
self.unit).value)
else:
lines[-1].append(line_table["wavelength"].to(self.unit).value)
lines[-1] = np.ma.masked_array(np.sort(np.concatenate(lines[-1])))
if line_list_bounds is not None:
for i, lst in enumerate(lines):
lines[i] = lst[(lst >= line_list_bounds[0]) & (lst <= line_list_bounds[1])]
self.catalog_lines = lines
self._create_trees()
def _create_trees(self) -> None:
    """Build one KDTree per frame from the unmasked catalog line locations."""
    trees = []
    for locations in self.catalog_line_locations:
        # KDTree expects 2D input, so the wavelengths become a single-column array.
        unmasked = locations.compressed()
        trees.append(KDTree(unmasked[:, None]))
    self._trees = trees
[docs]
def find_lines(self, fwhm: float, noise_factor: float = 1.0) -> None:
    """Find lines in the provided arc spectra.

    Runs ``find_arc_lines`` on every arc spectrum and stores the centroids of the detected
    lines as the observed lines. The amplitude of a line is read at its rounded centroid
    from the spectrum flux after a maximum filter with a width of five pixels.

    Parameters
    ----------
    fwhm
        Initial guess for the FWHM for the spectral lines, used as a parameter in
        the ``find_arc_lines`` function to locate and identify spectral arc lines.
    noise_factor
        The factor to multiply the uncertainty by to determine the noise threshold
        in the `~specutils.fitting.find_lines_threshold` routine.

    Raises
    ------
    ValueError
        If no arc spectra are available, or if a detected centroid rounds to a pixel
        outside the spectrum.
    """
    if self.arc_spectra is None:
        raise ValueError("Must provide arc spectra to find lines.")

    detected = []
    for arc in self.arc_spectra:
        found = find_arc_lines(arc, fwhm, noise_factor=noise_factor)
        centroids = found["centroid"].value
        nearest_pixel = np.round(centroids).astype(int)

        outside = (nearest_pixel < 0) | (nearest_pixel >= arc.shape[0])
        if np.any(outside):
            raise ValueError(
                "Error in arc line identification. Try increasing ``noise_factor``."
            )

        peak_flux = ndimage.maximum_filter1d(arc.flux.value, 5)
        pairs = np.transpose([centroids, peak_flux[nearest_pixel]])
        detected.append(np.ma.masked_array(pairs))

    self.observed_lines = detected
def _create_model(self, degree: int, coeffs: None | ArrayLike = None) -> CompoundModel:
    """Create the pixel-to-wavelength model: a shift followed by a polynomial.

    Stores ``degree`` as the current polynomial degree and returns a compound model that
    first subtracts the reference pixel and then evaluates a polynomial of the given
    degree. If coefficients are given, they initialize the lowest-order polynomial
    coefficients; surplus coefficients are ignored.

    Parameters
    ----------
    degree
        Degree of the polynomial model to be initialized.
    coeffs
        Optional initial polynomial coefficients.

    Returns
    -------
    CompoundModel
        ``Shift(-ref_pixel) | Polynomial1D(degree)``.
    """
    self.degree = degree

    if coeffs is None:
        initial = {}
    else:
        # ``zip`` stops at the shorter input, i.e. at most ``degree + 1`` coefficients.
        initial = {f"c{i}": c for i, c in zip(range(degree + 1), coeffs)}

    shift = models.Shift(-self.ref_pixel)
    polynomial = models.Polynomial1D(self.degree, **initial)
    return shift | polynomial
[docs]
def fit_lines(
    self,
    pixels: ArrayLike,
    wavelengths: ArrayLike,
    degree: int = 3,
    match_obs: bool = False,
    match_cat: bool = False,
    refine_fit: bool = True,
    refine_max_distance: float = 5.0,
    refined_fit_degree: int | None = None,
) -> WavelengthSolution1D:
    """Fit the pixel-to-wavelength model using provided line pairs.

    This method fits the pixel-to-wavelength transformation using explicitly provided pairs
    of pixel coordinates and their corresponding wavelengths via a linear least-squares fit.

    Optionally, the provided pixel and wavelength values can be "snapped" to the nearest
    values present in the internally stored observed line list and catalog line list,
    respectively. This allows the inputs to be approximate, as the snapping step selects
    the nearest precise centroids and catalog values when available.

    Parameters
    ----------
    pixels
        An array of pixel positions corresponding to known spectral lines.
    wavelengths
        An array of the same size as ``pixels``, containing the known
        wavelengths corresponding to the given pixel positions.
    degree
        The polynomial degree for the wavelength solution. If the polynomial has more
        coefficients than there are lines, the highest-order coefficients are fixed to
        zero and a warning is issued.
    match_obs
        If True, snap the input ``pixels`` values to the nearest
        pixel values found in ``self.observed_line_locations``.
    match_cat
        If True, snap the input ``wavelengths`` values to the
        nearest wavelength values found in ``self.catalog_line_locations``.
    refine_fit
        If True (default), call the ``refine_fit`` method immediately after the initial
        fit to improve the solution using a least-squares fit on matched lines. Only done
        when both catalog and observed lines are available.
    refine_max_distance
        Maximum allowed separation between catalog and observed lines for them to
        be considered a match during ``refine_fit``. Ignored if ``refine_fit`` is False.
    refined_fit_degree
        The polynomial degree for the refined fit. Can be higher than ``degree``. If ``None``,
        equals to ``degree``.

    Returns
    -------
    WavelengthSolution1D
        The fitted wavelength solution.

    Raises
    ------
    ValueError
        If the input sizes differ, if fewer than two lines are given, if the pixel bounds
        are not set, or if snapping is requested without the corresponding line list.
    """
    pixels = np.asarray(pixels)
    wavelengths = np.asarray(wavelengths)
    if pixels.size != wavelengths.size:
        raise ValueError("The sizes of pixel and wavelength arrays must match.")
    nlines = pixels.size
    if nlines < 2:
        raise ValueError("Need at least two lines for a fit")
    if self.bounds_pix is None:
        raise ValueError("Cannot fit without pixel bounds set.")

    # Match the input wavelengths to catalog lines.
    if match_cat:
        if self._cat_lines is None:
            raise ValueError("Cannot fit without catalog lines set.")
        tree = KDTree(
            np.concatenate([c.compressed() for c in self.catalog_line_locations])[:, None]
        )
        ix = tree.query(wavelengths[:, None])[1]
        wavelengths = tree.data[ix][:, 0]

    # Match the input pixel values to observed pixel values.
    if match_obs:
        if self._obs_lines is None:
            raise ValueError("Cannot fit without observed lines set.")
        tree = KDTree(
            np.concatenate([c.compressed() for c in self.observed_line_locations])[:, None]
        )
        ix = tree.query(pixels[:, None])[1]
        pixels = tree.data[ix][:, 0]

    fitter = fitting.LinearLSQFitter()
    shift, model = self._create_model(degree)

    # A polynomial of degree d has d + 1 coefficients, so the fit is already
    # underdetermined when the degree equals the number of lines. Leave exactly
    # ``nlines`` coefficients free and fix the rest to zero.
    if model.degree >= nlines:
        warnings.warn(
            "The polynomial model has more coefficients than there are lines. "
            "Fixing the higher-order coefficients to zero.",
            stacklevel=2,
        )
        for i in range(nlines, model.degree + 1):
            model.fixed[f"c{i}"] = True

    model = fitter(model, pixels - self.ref_pixel, wavelengths)
    # Free all coefficients again so that later refinements can fit them.
    for i in range(model.degree + 1):
        model.fixed[f"c{i}"] = False
    self.solution.p2w = shift | model

    if self._cat_lines is not None and self._obs_lines is not None:
        if refine_fit:
            self.refine_fit(refined_fit_degree, max_match_distance=refine_max_distance)
        else:
            self.match_lines()
    return self.solution
[docs]
def fit_dispersion(
    self,
    wavelength_bounds: tuple[float, float],
    dispersion_bounds: tuple[float, float],
    higher_order_limits: Sequence[float] | None = None,
    degree: int = 3,
    popsize: int = 30,
    max_distance: float = 100,
    refine_fit: bool = True,
    refine_max_distance: float = 5.0,
    refined_fit_degree: int | None = None,
) -> WavelengthSolution1D:
    """Calculate a wavelength solution using all the catalog and observed lines.

    This method estimates a wavelength solution without pre-matched pixel–wavelength
    pairs, making it suitable for automated pipelines on stable, well-characterized
    spectrographs. It uses differential evolution to optimize the polynomial parameters
    that minimize the distance between the predicted wavelengths of the observed lines
    and their nearest catalog lines. The resulting solution can optionally be refined
    with a least-squares fit to automatically matched lines.

    Parameters
    ----------
    wavelength_bounds
        (min, max) bounds for the wavelength at ``ref_pixel``; used as an optimization
        constraint.
    dispersion_bounds
        (min, max) bounds for the dispersion d(wavelength)/d(pixel) at ``ref_pixel``; used
        as an optimization constraint. May be negative for spectra whose wavelength
        decreases with pixel number.
    higher_order_limits
        Absolute limits for the higher-order polynomial coefficients. Each coefficient is
        constrained to [-limit, limit]. If provided, the number of limits must equal
        (polynomial degree - 1). If `None`, the limit of the i:th order coefficient is
        ``|mean(dispersion_bounds)| * 10**(-2 i)``.
    degree
        The polynomial degree for the wavelength solution.
    popsize
        Population size for ``scipy.optimize.differential_evolution``. Larger values can
        improve the chance of finding the global minimum at the cost of additional time.
    max_distance
        Maximum wavelength separation used when associating observed and catalog lines in
        the optimization. Distances larger than this threshold are clipped to this value
        in the cost function to limit the impact of outliers.
    refine_fit
        If True (default), call ``refine_fit`` after global optimization to improve the
        solution using a least-squares fit on matched lines.
    refine_max_distance
        Maximum allowed separation between catalog and observed lines for them to
        be considered a match during ``refine_fit``. Ignored if ``refine_fit`` is False.
    refined_fit_degree
        The polynomial degree for the refined fit. Can be higher than ``degree``. If ``None``,
        equals to ``degree``.

    Returns
    -------
    WavelengthSolution1D
        The fitted wavelength solution.

    Raises
    ------
    ValueError
        If the catalog or observed lines are not set, or if the number of higher-order
        limits does not match the polynomial degree.
    """
    if self._cat_lines is None or self._obs_lines is None:
        raise ValueError("Cannot fit the dispersion without catalog and observed lines set.")
    if self._trees is None:
        # Catalog lines assigned through the property setter come without trees.
        self._create_trees()

    # Define bounds for differential_evolution.
    bounds = [np.asarray(wavelength_bounds), np.asarray(dispersion_bounds)]
    model = self._create_model(degree)
    poly_degree = model[1].degree
    if higher_order_limits is not None:
        if len(higher_order_limits) != poly_degree - 1:
            raise ValueError(
                "The number of higher-order limits must match the degree of the polynomial "
                "model minus one."
            )
        # The limits are absolute values; ``abs`` keeps lower <= upper for any sign.
        bounds.extend(np.asarray([-abs(v), abs(v)]) for v in higher_order_limits)
    else:
        # Scale with the magnitude of the dispersion. Taking the logarithm of the mean
        # itself would yield NaN bounds for a negative dispersion.
        scale = np.abs(np.mean(dispersion_bounds))
        for i in range(2, poly_degree + 1):
            bounds.append(np.array([-1, 1]) * scale * 10.0 ** (-2 * i))
    bounds = np.array(bounds)

    self._fit = optimize.differential_evolution(
        lambda x: self._line_match_distance(x, model, max_distance),
        bounds=bounds,
        popsize=popsize,
        init="sobol",
    )
    self.solution.p2w = self._create_model(degree, coeffs=self._fit.x)

    if refine_fit:
        self.refine_fit(refined_fit_degree, max_match_distance=refine_max_distance)
    else:
        self.match_lines()
    return self.solution
[docs]
def refine_fit(
    self, degree: None | int = None, max_match_distance: float = 5.0, max_iter: int = 5
) -> WavelengthSolution1D:
    """Refine the pixel-to-wavelength transformation fit.

    Alternates between matching the observed lines to the catalog lines with the current
    solution and re-fitting the polynomial to the matched pairs with a linear
    least-squares fit. The iteration stops when the RMS of the residuals no longer
    changes or after ``max_iter`` iterations.

    Parameters
    ----------
    degree
        The polynomial degree for the wavelength solution. If ``None``, the degree
        previously set by the `~WavelengthCalibration1D.fit_lines` or
        `~WavelengthCalibration1D.fit_dispersion` method will be used.
    max_match_distance
        Maximum allowable distance used to identify matched pixel and wavelength
        data points. Points exceeding the bound will not be considered in the fit.
    max_iter
        Maximum number of fitting iterations.

    Returns
    -------
    WavelengthSolution1D
        The refined wavelength solution.
    """
    # Create a new model with the current parameters if the degree changes.
    if degree is not None and degree != self.degree:
        model = self._create_model(degree, coeffs=self.solution.p2w[1].parameters)
    else:
        model = self.solution.p2w
    shift, poly = model

    fitter = fitting.LinearLSQFitter()
    rms = np.nan
    for _ in range(max_iter):
        # ``match_lines`` works on ``self.solution.p2w``, so the latest model has to be
        # stored before matching. Otherwise every iteration would repeat the matches of
        # the initial solution.
        self.solution.p2w = model
        self.match_lines(max_match_distance)
        matched_pix = np.ma.concatenate(self.observed_line_locations).compressed()
        matched_wav = np.ma.concatenate(self.catalog_line_locations).compressed()
        rms_new = np.sqrt(((matched_wav - model(matched_pix)) ** 2).mean())
        if isclose(rms_new, rms):
            break
        model = shift | fitter(poly, matched_pix - self.ref_pixel, matched_wav)
        rms = rms_new

    self.solution.p2w = model
    return self.solution
@property
def degree(self) -> None | int:
return self._degree
@degree.setter
def degree(self, degree: int | None):
if degree is not None and degree < 1:
raise ValueError("Degree must be at least 1.")
self._degree = degree
@property
def observed_lines(self) -> None | list[MaskedArray]:
"""Pixel positions and amplitudes of the observed lines as a list of masked arrays."""
return self._obs_lines
[docs]
@cached_property
def observed_line_locations(self) -> None | list[MaskedArray]:
"""Pixel positions of the observed lines as a list of masked arrays."""
if self._obs_lines is None:
return None
else:
return [line[:, 0] for line in self._obs_lines]
[docs]
@cached_property
def observed_line_amplitudes(self) -> None | list[MaskedArray]:
"""Amplitudes of the observed lines as a list of masked arrays."""
if self._obs_lines is None:
return None
else:
return [line[:, 1] for line in self._obs_lines]
@observed_lines.setter
def observed_lines(self, line_lists: ArrayLike | list[ArrayLike]):
if not isinstance(line_lists, Sequence):
line_lists = [line_lists]
self._obs_lines = []
for lst in line_lists:
self._obs_lines.append(_format_linelist(lst))
if hasattr(self, "observed_line_locations"):
del self.observed_line_locations
if hasattr(self, "observed_line_amplitudes"):
del self.observed_line_amplitudes
@property
def catalog_lines(self) -> None | list[MaskedArray]:
"""Catalog line wavelengths as a list of masked arrays."""
return self._cat_lines
[docs]
@cached_property
def catalog_line_locations(self) -> None | list[MaskedArray]:
"""Pixel positions of the catalog lines as a list of masked arrays."""
if self._cat_lines is None:
return None
else:
return [line[:, 0] for line in self._cat_lines]
[docs]
@cached_property
def catalog_line_amplitudes(self) -> None | list[MaskedArray]:
"""Amplitudes of the catalog lines as a list of masked arrays."""
if self._obs_lines is None:
return None
else:
return [line[:, 1] for line in self._cat_lines]
@catalog_lines.setter
def catalog_lines(self, line_lists: ArrayLike | list[ArrayLike]):
if not isinstance(line_lists, Sequence):
line_lists = [line_lists]
self._cat_lines = []
for lst in line_lists:
self._cat_lines.append(_format_linelist(lst))
if hasattr(self, "catalog_line_locations"):
del self.catalog_line_locations
if hasattr(self, "catalog_line_amplitudes"):
del self.catalog_line_amplitudes
[docs]
def match_lines(self, max_distance: float = 5) -> None:
    """Match the observed lines to the catalog lines.

    For every frame, the observed line positions are transformed with the current
    pixel-to-wavelength solution and paired with their nearest catalog line if it lies
    within ``max_distance``. If several observed lines claim the same catalog line, only
    the nearest one keeps the match. The result is stored in the masks of the line lists:
    matched lines are unmasked and all other lines are masked.

    Parameters
    ----------
    max_distance
        The maximum allowed distance between the catalog and observed lines for them to be
        considered a match.
    """
    for iframe, tree in enumerate(self._trees):
        predicted = self.solution.p2w(self.observed_line_locations[iframe].data)
        # Lines without a catalog line within reach get an infinite distance.
        dist, cat_ix = tree.query(predicted[:, None], distance_upper_bound=max_distance)
        matched = np.isfinite(dist)

        # Resolve catalog lines claimed by several observed lines by keeping only the
        # nearest claimant. This isn't an optimal solution, the match could also be
        # iterated after removing the matched lines, but it works for now.
        claimed, n_claims = np.unique(cat_ix[matched], return_counts=True)
        for contested in claimed[n_claims > 1]:
            rivals = np.flatnonzero(cat_ix == contested)
            winner = rivals[np.argmin(dist[rivals])]
            matched[rivals] = False
            matched[winner] = True

        cat_mask = self._cat_lines[iframe].mask
        cat_mask[:, :] = True
        cat_mask[cat_ix[matched], :] = False
        self._obs_lines[iframe].mask[:, :] = ~matched[:, None]
[docs]
def remove_unmatched_lines(self) -> None:
    """Remove unmatched lines from observation and catalog line data.

    Replaces the observed and the catalog line lists with their unmasked rows and
    rebuilds the KDTrees.
    """

    def unmasked_rows(line_lists):
        # ``compressed`` flattens the array, so restore the two-column layout.
        return [lines.compressed().reshape([-1, 2]) for lines in line_lists]

    self.observed_lines = unmasked_rows(self._obs_lines)
    self.catalog_lines = unmasked_rows(self._cat_lines)
    self._create_trees()
[docs]
def rms(self, space: Literal["pixel", "wavelength"] = "wavelength") -> float:
    """Compute the RMS of the residuals between matched lines in the pixel or wavelength space.

    The lines are re-matched with the default matching distance before the residuals
    are calculated.

    Parameters
    ----------
    space
        The space in which to calculate the RMS residual. If 'wavelength',
        the calculation is performed in the wavelength space. If 'pixel',
        it is performed in the pixel space. Default is 'wavelength'.

    Returns
    -------
    float
        Root mean square of the residuals of the matched lines.

    Raises
    ------
    ValueError
        If ``space`` is neither 'pixel' nor 'wavelength'.
    """
    self.match_lines()
    pix = np.ma.concatenate(self.observed_line_locations).compressed()
    wav = np.ma.concatenate(self.catalog_line_locations).compressed()

    if space == "wavelength":
        residuals = wav - self.solution.p2w(pix)
    elif space == "pixel":
        residuals = pix - self.solution.w2p(wav)
    else:
        raise ValueError("Space must be either 'pixel' or 'wavelength'")
    return np.sqrt((residuals**2).mean())
def _plot_lines(
    self,
    kind: Literal["observed", "catalog"],
    frames: int | Sequence[int] | None = None,
    axes: Axes | Sequence[Axes] | None = None,
    figsize: tuple[float, float] | None = None,
    plot_labels: bool | Sequence[bool] = True,
    map_x: bool = False,
    label_kwargs: dict | None = None,
) -> Figure:
    """
    Plot lines with optional features such as wavelength mapping and label customization.

    Parameters
    ----------
    kind
        Specifies the line list to plot. With "observed", the arc spectra are drawn as
        well if they are available.
    frames
        Frame indices to plot. If None, all frames are plotted.
    axes
        Axes object(s) where the lines should be plotted, one per plotted frame. If None,
        new Axes are generated.
    figsize
        Size of the figure to use if creating new Axes. Ignored if axes are provided.
    plot_labels
        Flag(s) indicating whether to display labels for the lines. If a single value is
        provided, it is applied to all frames. Otherwise one flag per plotted frame.
    map_x
        If True, maps the x-axis values between pixel and wavelength space.
    label_kwargs
        Additional keyword arguments to customize the label style.

    Returns
    -------
    Figure
        The Figure object containing the plotted spectral lines.

    Raises
    ------
    ValueError
        If ``map_x`` is True but no wavelength solution has been fitted.
    """

    def label_value(value):
        # Round to four significant digits. The logarithm is taken of the absolute value
        # because mapped positions can be zero or negative.
        if value == 0 or not np.isfinite(value):
            return value
        return np.round(value, 4 - 1 - int(np.floor(np.log10(np.abs(value)))))

    largs = dict(backgroundcolor="w", rotation=90, size="small")
    if label_kwargs is not None:
        largs.update(label_kwargs)

    frames = np.arange(self.nframes) if frames is None else np.atleast_1d(frames)

    if axes is None:
        fig, axes = subplots(
            frames.size, 1, figsize=figsize, constrained_layout=True, sharex="all"
        )
    elif isinstance(axes, Axes):
        fig = axes.figure
        axes = [axes]
    else:
        fig = axes[0].figure
    axes = np.atleast_1d(axes)

    if isinstance(plot_labels, bool):
        plot_labels = np.full(frames.size, plot_labels, dtype=bool)

    if map_x and self.solution.p2w is None:
        raise ValueError("Cannot map between pixels and wavelengths without a fitted model.")

    if kind == "observed":
        transform = self.solution.pix_to_wav if map_x else lambda x: x
        linelists = self.observed_lines
        spectra = self.arc_spectra
        lc = "C0"
    else:
        transform = self.solution.wav_to_pix if map_x else lambda x: x
        linelists = self.catalog_lines
        spectra = None
        lc = "C1"

    ypad = 1.3
    labels = []
    # ``iframe`` counts the plotted panels (axes, label flags and label lists), while
    # ``frame`` is the index of the data frame shown in the panel.
    for iframe, (ax, frame) in enumerate(zip(axes, frames)):
        if spectra is not None:
            spc = spectra[frame]
            vmax = np.nanmax(spc.flux.value)
            ax.plot(transform(spc.spectral_axis.value), spc.flux.value / vmax, "k")
        else:
            vmax = 1.0

        if linelists is not None:
            labels.append([])
            linelist = linelists[frame]
            # Loop over individual lines in the line list.
            for i in range(linelist.shape[0]):
                c, a = linelist.data[i]
                # Matched (unmasked) lines are solid, unmatched lines are dotted.
                ls = "-" if linelist.mask[i, 0] == 0 else ":"
                ax.plot(transform([c, c]), [a / vmax + 0.1, 1.27], c=lc, ls=ls, zorder=-100)
                if plot_labels[iframe]:
                    lloc = transform(c)
                    text = ax.text(
                        lloc, ypad, label_value(lloc), ha="center", va="top", **largs
                    )
                    text.set_clip_on(True)
                    # The amplitude doubles as the label priority when uncluttering.
                    text.zorder = a
                    labels[-1].append(text)

    if (kind == "observed" and not map_x) or (kind == "catalog" and map_x):
        xlabel = "Pixel"
    else:
        xlabel = f"Wavelength [{self._unit_str}]"

    if kind == "catalog":
        axes[0].xaxis.set_label_position("top")
        axes[0].xaxis.tick_top()
        setp(axes[0], xlabel=xlabel)
        for ax in axes[1:]:
            ax.set_xticklabels([])
    else:
        setp(axes[-1], xlabel=xlabel)
        for ax in axes[:-1]:
            ax.set_xticklabels([])

    xlims = np.array([ax.get_xlim() for ax in axes])
    setp(axes, xlim=(xlims[:, 0].min(), xlims[:, 1].max()), yticks=[])

    if linelists is not None:
        # The label extents are queried only after the figure has been drawn.
        fig.canvas.draw()
        for i in range(len(frames)):
            # A panel without labels has no extents to fit the y-axis to.
            if plot_labels[i] and labels[i]:
                # Calculate the label bounding box upper limits and adjust the y-axis limits.
                tr_to_data = axes[i].transData.inverted()
                ymax = max(
                    tr_to_data.transform(lb.get_window_extent().p1)[1] for lb in labels[i]
                )
                setp(axes[i], ylim=(-0.04, ymax * 1.06))

                # Remove the overlapping labels prioritizing the high-amplitude lines.
                _unclutter_text_boxes(labels[i])
    return fig
[docs]
def plot_catalog_lines(
    self,
    frames: int | Sequence[int] | None = None,
    axes: Axes | Sequence[Axes] | None = None,
    figsize: tuple[float, float] | None = None,
    plot_labels: bool | Sequence[bool] = True,
    map_to_pix: bool = False,
    label_kwargs: dict | None = None,
) -> Figure:
    """Plot the catalog lines.

    Parameters
    ----------
    frames
        The frame(s) to plot, given as a single index or a sequence of indices. If None,
        all frames are plotted.
    axes
        The matplotlib axes to plot on. If None, new axes are created.
    figsize
        The dimensions of the figure as (width, height), used when new axes are created.
    plot_labels
        If True, the line locations are printed over the plotted lines.
    map_to_pix
        If True, the catalog lines are mapped to pixel coordinates before plotting.
    label_kwargs
        Specifies the keyword arguments for the line label text objects.

    Returns
    -------
    Figure
        The matplotlib figure containing the plotted catalog lines.
    """
    options = dict(
        frames=frames,
        axes=axes,
        figsize=figsize,
        plot_labels=plot_labels,
        map_x=map_to_pix,
        label_kwargs=label_kwargs,
    )
    return self._plot_lines("catalog", **options)
[docs]
def plot_observed_lines(
    self,
    frames: int | Sequence[int] | None = None,
    axes: Axes | Sequence[Axes] | None = None,
    figsize: tuple[float, float] | None = None,
    plot_labels: bool | Sequence[bool] = True,
    map_to_wav: bool = False,
    label_kwargs: dict | None = None,
) -> Figure:
    """Plot observed spectral lines for the given arc spectra.

    Parameters
    ----------
    frames
        The frame(s) to plot, given as a single index or a sequence of indices. If None,
        all frames are plotted.
    axes
        Axes object(s) to plot the spectral lines on. If None, new axes are created.
    figsize
        Dimensions of the figure to be created, specified as a tuple (width, height). Ignored
        if ``axes`` is provided.
    plot_labels
        If True, the line locations are printed over the plotted lines.
    map_to_wav
        Determines whether to map the x-axis values to wavelengths.
    label_kwargs
        Specifies the keyword arguments for the line label text objects.

    Returns
    -------
    Figure
        The matplotlib figure containing the observed lines plot.
    """
    options = dict(
        frames=frames,
        axes=axes,
        figsize=figsize,
        plot_labels=plot_labels,
        map_x=map_to_wav,
        label_kwargs=label_kwargs,
    )
    figure = self._plot_lines("observed", **options)
    # Fit the x-axis limits of every axes in the figure tightly to the plotted data.
    for panel in figure.axes:
        panel.autoscale(True, "x", tight=True)
    return figure
[docs]
def plot_fit(
    self,
    frames: Sequence[int] | int | None = None,
    figsize: tuple[float, float] | None = None,
    plot_labels: bool | Sequence[bool] = True,
    obs_to_wav: bool = False,
    cat_to_pix: bool = False,
    label_kwargs: dict | None = None,
) -> Figure:
    """Plot the fitted catalog and observed lines for the specified arc spectra.

    Creates two panels per frame: the catalog lines on top and the observed lines below.

    Parameters
    ----------
    frames
        The indices of the frames to plot. If `None`, all frames from 0 to
        ``self.nframes - 1`` are plotted.
    figsize
        Defines the width and height of the figure in inches. If `None`, the
        default size is used.
    plot_labels
        If `True`, print line locations over the plotted lines. Can also be a list with
        the same length as ``frames``.
    obs_to_wav
        If `True`, transform the x-axis of observed lines to the wavelength domain
        using the fitted wavelength solution.
    cat_to_pix
        If `True`, transforms catalog data points to pixel values before plotting.
    label_kwargs
        Specifies the keyword arguments for the line label text objects.

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the generated subplots.
    """
    frames = np.arange(self.nframes) if frames is None else np.atleast_1d(frames)

    fig, axs = subplots(2 * frames.size, 1, constrained_layout=True, figsize=figsize)
    # Even rows show the catalog lines, odd rows the observed lines.
    self.plot_catalog_lines(
        frames,
        axs[0::2],
        plot_labels=plot_labels,
        map_to_pix=cat_to_pix,
        label_kwargs=label_kwargs,
    )
    self.plot_observed_lines(
        frames,
        axs[1::2],
        plot_labels=plot_labels,
        map_to_wav=obs_to_wav,
        label_kwargs=label_kwargs,
    )

    # Use the x-range of the catalog panels either for all panels (observed lines mapped
    # to wavelengths) or for the catalog panels only.
    xlims = np.array([ax.get_xlim() for ax in axs[::2]])
    xlim = (xlims[:, 0].min(), xlims[:, 1].max())
    setp(axs if obs_to_wav else axs[::2], xlim=xlim)

    # The top panel shows the catalog lines, which are in pixels when mapped.
    top_label = "Pixel" if cat_to_pix else f"Wavelength [{self._unit_str}]"
    setp(axs[0], yticks=[], xlabel=top_label)
    for ax in axs[1:-1]:
        ax.set_xlabel("")
        ax.set_xticklabels("")
    axs[0].xaxis.set_label_position("top")
    axs[0].xaxis.tick_top()
    return fig
[docs]
def plot_residuals(
    self,
    ax: Axes | None = None,
    space: Literal["pixel", "wavelength"] = "wavelength",
    figsize: tuple[float, float] | None = None,
) -> Figure:
    """Plot the residuals of pixel-to-wavelength or wavelength-to-pixel transformation.

    The lines are re-matched with the default matching distance, and the residuals of
    the matched lines are plotted together with their RMS.

    Parameters
    ----------
    ax
        Matplotlib Axes object to plot on. If None, a new figure and axes are created.
    space
        The reference space used for plotting residuals. Options are 'pixel' for residuals
        in pixel space or 'wavelength' for residuals in wavelength space.
    figsize
        The size of the figure in inches, if a new figure is created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the residual plot.

    Raises
    ------
    ValueError
        If ``space`` is neither 'pixel' nor 'wavelength'.
    """
    if ax is None:
        fig, ax = subplots(figsize=figsize, constrained_layout=True)
    else:
        fig = ax.figure

    self.match_lines()
    mpix = np.ma.concatenate(self.observed_line_locations).compressed()
    mwav = np.ma.concatenate(self.catalog_line_locations).compressed()

    # Collect what differs between the two spaces, then draw everything in one place.
    if space == "wavelength":
        residuals = mwav - self.solution.pix_to_wav(mpix)
        x = mwav
        rms_text = f"RMS = {np.sqrt((residuals ** 2).mean()):4.2f} {self._unit_str}"
        axis_labels = dict(
            xlabel=f"Wavelength [{self._unit_str}]",
            ylabel=f"Residuals [{self._unit_str}]",
        )
    elif space == "pixel":
        residuals = mpix - self.solution.wav_to_pix(mwav)
        x = mpix
        rms_text = f"RMS = {np.sqrt((residuals ** 2).mean()):4.2f} pix"
        axis_labels = dict(xlabel="Pixel", ylabel="Residuals [pix]")
    else:
        raise ValueError("Invalid space specified for plotting residuals.")

    ax.plot(x, residuals, ".")
    ax.text(0.98, 0.95, rms_text, transform=ax.transAxes, ha="right", va="top")
    setp(ax, **axis_labels)
    ax.axhline(0, c="k", lw=1, ls="--")
    return fig