# Source code for specreduce.wavecal1d

import warnings
from functools import cached_property
from math import isclose
from typing import Sequence, Literal

import astropy.units as u
import numpy as np
from astropy.modeling import models, Model, fitting, CompoundModel
from matplotlib.pyplot import Axes, Figure, setp, subplots
from numpy.ma import MaskedArray
from numpy.typing import ArrayLike
from scipy import optimize, ndimage
from scipy.spatial import KDTree

from specreduce.calibration_data import load_pypeit_calibration_lines
from specreduce.compat import Spectrum
from specreduce.line_matching import find_arc_lines

__all__ = ["WavelengthCalibration1D"]

from specreduce.wavesol1d import WavelengthSolution1D


def _format_linelist(lst: ArrayLike) -> MaskedArray:
    """Force a line list into a MaskedArray with a shape of (n, 2) where n is the number of lines.

    The input is copied, so the caller's array is never modified. One-dimensional input is
    interpreted as centroids only and is given a zero amplitude column. The returned rows are
    sorted by increasing centroid.

    Parameters
    ----------
    lst
        Input array of centroids or centroids with amplitudes. Must be either:
            - A 1D array with a shape [n] for centroids.
            - A 2D array with a shape [n, 2] for centroids and amplitudes.

    Returns
    -------
    numpy.ma.MaskedArray
        Formatted and standardized line list array with shape [n, 2], where each row
        contains a line centroid and amplitude.

    Raises
    ------
    ValueError
        If the input line list does not meet the specified dimensional or shape
        requirements (this includes scalars and 2D arrays whose second axis is not of
        length two).
    """
    lst: MaskedArray = MaskedArray(lst, copy=True)
    # Expand a possible ``nomask`` into a full boolean array so that the mask can be
    # edited element-wise later on.
    lst.mask = np.ma.getmaskarray(lst)

    # Only [n] and [n, 2] are valid. Scalars and [n, 1] / [n, 0] arrays must be rejected
    # here, otherwise they fail later with an unrelated IndexError or yield a list without
    # an amplitude column.
    if lst.ndim not in (1, 2) or (lst.ndim == 2 and lst.shape[1] != 2):
        raise ValueError(
            "Line lists must be 1D with a shape [n] (centroids) or "
            "2D with a shape [n, 2] (centroids and amplitudes)."
        )

    if lst.ndim == 1:
        lst = MaskedArray(np.tile(lst[:, None], [1, 2]))
        lst[:, 1] = 0.0
        # Assigning the amplitudes unmasks the second column; restore a row-wise mask so
        # that a masked centroid masks the whole line.
        lst.mask[:, :] = lst.mask.any(axis=1)[:, None]

    return lst[np.argsort(lst.data[:, 0])]


def _unclutter_text_boxes(labels: Sequence) -> None:
    """Drop text labels whose bounding boxes collide with another label.

    Every unordered pair of labels is tested for overlapping window extents. For each
    colliding pair, the label with the strictly lower z-order is scheduled for removal;
    when the z-orders are equal, the label that comes later in ``labels`` is scheduled.
    All scheduled labels are removed once the pairwise scan has finished.

    Parameters
    ----------
    labels
        A list of matplotlib.text.Text objects.
    """
    doomed = set()
    for position, first in enumerate(labels):
        first_box = first.get_window_extent()
        for second in labels[position + 1:]:
            if not first_box.overlaps(second.get_window_extent()):
                continue
            doomed.add(first if first.zorder < second.zorder else second)

    for victim in doomed:
        victim.remove()


class WavelengthCalibration1D:
    def __init__(
        self,
        arc_spectra: Spectrum | Sequence[Spectrum] | None = None,
        obs_lines: ArrayLike | Sequence[ArrayLike] | None = None,
        line_lists: ArrayLike | None = None,
        unit: u.Unit = u.angstrom,
        ref_pixel: float | None = None,
        pix_bounds: tuple[int, int] | None = None,
        line_list_bounds: None | tuple[float, float] = None,
        n_strogest_lines: None | int = None,
        wave_air: bool = False,
    ) -> None:
        """A class for wavelength calibration of one-dimensional spectra.

        This class is designed to facilitate wavelength calibration of one-dimensional spectra,
        with support for both direct input of line lists and observed spectra. It uses a
        polynomial model for fitting the wavelength solution and offers features to incorporate
        catalog lines and observed line positions.

        Parameters
        ----------
        arc_spectra
            Arc spectra provided as ``Spectrum`` objects for wavelength fitting, by default
            None. This parameter and ``obs_lines`` cannot be provided simultaneously.
        obs_lines
            Pixel positions of observed spectral lines for wavelength fitting, by default None.
            This parameter and ``arc_spectra`` cannot be provided simultaneously.
        line_lists
            Catalogs of spectral line wavelengths for wavelength calibration. Provide either an
            array of line wavelengths or a list of
            `PypeIt <https://github.com/pypeit/PypeIt>`_ catalog names. If `None`, no line
            lists are used. You can query the list of available catalog names via
            `~specreduce.calibration_data.get_available_line_catalogs`.
        unit
            The unit of the wavelength calibration, by default ``astropy.units.Angstrom``.
        ref_pixel
            The reference pixel in which the wavelength solution will be centered.
            Mandatory if ``obs_lines`` is provided.
        pix_bounds
            Lower and upper pixel bounds for fitting, defined as a tuple (min, max). If
            ``obs_lines`` is provided, this parameter is mandatory.
        line_list_bounds
            Wavelength bounds as a tuple (min, max) for filtering usable spectral lines from
            the provided line lists.
        n_strogest_lines
            The number of strongest lines to be included from the line lists. If `None`, all
            are included.
        wave_air
            Boolean indicating whether the input wavelengths correspond to air rather than
            vacuum; by default `False`, meaning vacuum wavelengths.

        Raises
        ------
        ValueError
            If the reference pixel is negative, if both ``arc_spectra`` and ``obs_lines`` are
            given, if the arc spectra are not one dimensional or differ in length, if
            ``obs_lines`` is given without ``pix_bounds`` or ``ref_pixel``, or if the number
            of line lists does not match the number of frames.
        """
        self.unit = unit
        self._unit_str = unit.to_string("latex")
        self.degree = None
        self.ref_pixel = ref_pixel
        self.nframes = 0

        if ref_pixel is not None and ref_pixel < 0:
            raise ValueError("Reference pixel must be positive.")

        self.arc_spectra: list[Spectrum] | None = None
        self.bounds_pix: tuple[int, int] | None = pix_bounds
        self.bounds_wav: tuple[float, float] | None = None
        self._cat_lines: list[MaskedArray] | None = None
        self._obs_lines: list[MaskedArray] | None = None
        self._trees: list[KDTree] | None = None
        self._fit: optimize.OptimizeResult | None = None
        self.solution = WavelengthSolution1D(None, pix_bounds, unit)

        # Read and store the observational data if given. The user can provide either a list
        # of arc spectra as Spectrum objects or a list of line pixel position arrays. An
        # attempt to give both raises an error.
        if arc_spectra is not None and obs_lines is not None:
            raise ValueError("Only one of arc_spectra or obs_lines can be provided.")

        if arc_spectra is not None:
            self.arc_spectra = [arc_spectra] if isinstance(arc_spectra, Spectrum) else arc_spectra
            self.nframes = len(self.arc_spectra)
            for s in self.arc_spectra:
                if s.data.ndim > 1:
                    raise ValueError("The arc spectrum must be one dimensional.")
            if len({s.data.size for s in self.arc_spectra}) != 1:
                raise ValueError("All arc spectra must have the same length.")
            self.bounds_pix = (0, self.arc_spectra[0].shape[0])
            self.solution.bounds_pix = self.bounds_pix
            if self.ref_pixel is None:
                self.ref_pixel = self.arc_spectra[0].data.size / 2

        elif obs_lines is not None:
            self.observed_lines = obs_lines
            self.nframes = len(self._obs_lines)
            if self.bounds_pix is None:
                raise ValueError("Must give pixel bounds when providing observed line positions.")
            if self.ref_pixel is None:
                raise ValueError("Must give reference pixel when providing observed lines.")

        # Read the line lists if given. The user can provide an array of line wavelength
        # positions or a list of line list names (used by `load_pypeit_calibration_lines`)
        # for each arc spectrum.
        if line_lists is not None:
            if not isinstance(line_lists, (tuple, list)):
                line_lists = [line_lists]
            if len(line_lists) != self.nframes:
                raise ValueError("The number of line lists must match the number of arc spectra.")
            self._read_linelists(
                line_lists,
                line_list_bounds=line_list_bounds,
                wave_air=wave_air,
                n_strongest=n_strogest_lines,
            )

    def _line_match_distance(self, x: ArrayLike, model: Model, max_distance: float = 100) -> float:
        """Compute the sum of distances between catalog lines and transformed observed lines.

        This function evaluates the pixel-to-wavelength model at the observed line positions,
        queries the nearest catalog line (via KDTree), and sums the distances after clipping
        them at `max_distance`. The result is suitable as a scalar objective for global
        optimization of the wavelength solution.

        Parameters
        ----------
        x
            Pixel-to-wavelength model parameters (e.g., Polynomial1D coefficients c0..cN).
        model
            The pixel-to-wavelength model to be evaluated.
        max_distance
            Upper bound used to clip individual distances before summation.

        Returns
        -------
        float
            Sum of nearest-neighbor distances between transformed observed lines and catalog
            lines.
        """
        total_distance = 0.0
        for tree, pix in zip(self._trees, self.observed_line_locations):
            # The first model parameter is the shift offset, followed by the polynomial
            # coefficients that are being optimized.
            transformed_lines = model.evaluate(pix, -self.ref_pixel, *x)[:, None]
            total_distance += np.clip(tree.query(transformed_lines)[0], 0, max_distance).sum()
        return total_distance

    def _read_linelists(
        self,
        line_lists: Sequence,
        line_list_bounds: None | tuple[float, float] = None,
        wave_air: bool = False,
        n_strongest: None | int = None,
    ) -> None:
        """Read and process line lists.

        Stores the resulting catalog lines in ``self.catalog_lines`` and rebuilds the
        KDTrees used for line matching.

        Parameters
        ----------
        line_lists
            A collection of line lists that can either be arrays of wavelengths or
            `PypeIt <https://github.com/pypeit/PypeIt>`_ lamp names. You can query the list of
            available catalog names via
            `~specreduce.calibration_data.get_available_line_catalogs`.
        line_list_bounds
            A tuple specifying the minimum and maximum wavelength bounds. Only wavelengths
            within this range are retained.
        wave_air
            If True, convert the vacuum wavelengths used by
            `PypeIt <https://github.com/pypeit/PypeIt>`_ to air wavelengths.
        n_strongest
            The number of strongest lines to be used from each named catalog. If `None`, all
            lines are used. Has no effect on line lists given as arrays.
        """
        lines = []
        for lst in line_lists:
            if isinstance(lst, np.ndarray):
                lines.append(lst)
                continue

            # A frame can combine several named catalogs; a single name is a list of one.
            if isinstance(lst, str):
                lst = [lst]
            wavelengths = []
            for name in lst:
                line_table = load_pypeit_calibration_lines(name, wave_air=wave_air)
                if n_strongest is not None:
                    ix = np.argsort(line_table["amplitude"].value)[::-1]
                    line_table = line_table[ix][:n_strongest]
                wavelengths.append(line_table["wavelength"].to(self.unit).value)
            lines.append(np.ma.masked_array(np.sort(np.concatenate(wavelengths))))

        if line_list_bounds is not None:
            for i, lst in enumerate(lines):
                lines[i] = lst[(lst >= line_list_bounds[0]) & (lst <= line_list_bounds[1])]

        self.catalog_lines = lines
        self._create_trees()

    def _create_trees(self) -> None:
        """Initialize the KDTree instances for the current set of catalog line locations."""
        self._trees = [KDTree(lst.compressed()[:, None]) for lst in self.catalog_line_locations]

    def find_lines(self, fwhm: float, noise_factor: float = 1.0) -> None:
        """Find lines in the provided arc spectra.

        Determines the spectral lines within each spectrum of the arc spectra based on the
        provided initial guess for the line Full Width at Half Maximum (FWHM). The detected
        centroids and their amplitudes are stored in ``self.observed_lines``.

        Parameters
        ----------
        fwhm
            Initial guess for the FWHM for the spectral lines, used as a parameter in the
            ``find_arc_lines`` function to locate and identify spectral arc lines.
        noise_factor
            The factor to multiply the uncertainty by to determine the noise threshold in the
            `~specutils.fitting.find_lines_threshold` routine.

        Raises
        ------
        ValueError
            If no arc spectra are available, or if a detected centroid falls outside the
            spectrum.
        """
        if self.arc_spectra is None:
            raise ValueError("Must provide arc spectra to find lines.")

        line_lists = []
        for arc in self.arc_spectra:
            lines = find_arc_lines(arc, fwhm, noise_factor=noise_factor)
            ix = np.round(lines["centroid"].value).astype(int)
            if np.any((ix < 0) | (ix >= arc.shape[0])):
                raise ValueError(
                    "Error in arc line identification. Try increasing ``noise_factor``."
                )
            # Use the local maximum around the centroid as the line amplitude.
            amplitudes = ndimage.maximum_filter1d(arc.flux.value, 5)[ix]
            line_lists.append(
                np.ma.masked_array(np.transpose([lines["centroid"].value, amplitudes]))
            )
        self.observed_lines = line_lists

    def _create_model(self, degree: int, coeffs: None | ArrayLike = None) -> CompoundModel:
        """Initialize the polynomial model with the given degree and optional coefficients.

        This method sets up a polynomial transformation based on the reference pixel and
        degree. If coefficients are provided, they are copied to the initialized model up to
        the degree specified. As a side effect, ``self.degree`` is set to ``degree``.

        Parameters
        ----------
        degree
            Degree of the polynomial model to be initialized.
        coeffs
            Optional initial polynomial coefficients.

        Returns
        -------
        CompoundModel
            A ``Shift(-ref_pixel) | Polynomial1D(degree)`` model.
        """
        self.degree = degree
        pars = {}
        if coeffs is not None:
            nc = min(degree + 1, len(coeffs))
            pars = {f"c{i}": c for i, c in enumerate(coeffs[:nc])}
        return models.Shift(-self.ref_pixel) | models.Polynomial1D(self.degree, **pars)

    def fit_lines(
        self,
        pixels: ArrayLike,
        wavelengths: ArrayLike,
        degree: int = 3,
        match_obs: bool = False,
        match_cat: bool = False,
        refine_fit: bool = True,
        refine_max_distance: float = 5.0,
        refined_fit_degree: int | None = None,
    ) -> WavelengthSolution1D:
        """Fit the pixel-to-wavelength model using provided line pairs.

        This method fits the pixel-to-wavelength transformation using explicitly provided
        pairs of pixel coordinates and their corresponding wavelengths via a linear
        least-squares fit.

        Optionally, the provided pixel and wavelength values can be "snapped" to the nearest
        values present in the internally stored observed line list and catalog line list,
        respectively. This allows the inputs to be approximate, as the snapping step selects
        the nearest precise centroids and catalog values when available.

        Parameters
        ----------
        pixels
            An array of pixel positions corresponding to known spectral lines.
        wavelengths
            An array of the same size as ``pixels``, containing the known wavelengths
            corresponding to the given pixel positions.
        degree
            The polynomial degree for the wavelength solution.
        match_obs
            If True, snap the input ``pixels`` values to the nearest pixel values found in
            ``self.observed_line_locations`` (if available). This helps ensure the fit uses
            the precise centroids detected by `find_lines` or provided initially.
        match_cat
            If True, snap the input ``wavelengths`` values to the nearest wavelength values
            found in ``self.catalog_line_locations`` (if available). This ensures the fit uses
            the precise catalog wavelengths.
        refine_fit
            If True (default), automatically call the ``refine_fit`` method immediately after
            the initial least-squares fit to improve the solution using all the matched
            lines. Requires both catalog and observed lines; otherwise it is skipped.
        refine_max_distance
            Maximum allowed separation between catalog and observed lines for them to be
            considered a match during ``refine_fit``. Ignored if ``refine_fit`` is False.
        refined_fit_degree
            The polynomial degree for the refined fit. Can be higher than ``degree``. If
            ``None``, equals to ``degree``.

        Returns
        -------
        WavelengthSolution1D
            The fitted wavelength solution.

        Raises
        ------
        ValueError
            If the input sizes differ, fewer than two lines are given, the pixel bounds are
            not set, or snapping is requested without the corresponding line list.
        """
        pixels = np.asarray(pixels)
        wavelengths = np.asarray(wavelengths)
        if pixels.size != wavelengths.size:
            raise ValueError("The sizes of pixel and wavelength arrays must match.")
        nlines = pixels.size
        if nlines < 2:
            raise ValueError("Need at least two lines for a fit")

        if self.bounds_pix is None:
            raise ValueError("Cannot fit without pixel bounds set.")

        # Match the input wavelengths to catalog lines.
        if match_cat:
            if self._cat_lines is None:
                raise ValueError("Cannot fit without catalog lines set.")
            tree = KDTree(
                np.concatenate([c.compressed() for c in self.catalog_line_locations])[:, None]
            )
            ix = tree.query(wavelengths[:, None])[1]
            wavelengths = tree.data[ix][:, 0]

        # Match the input pixel values to observed pixel values.
        if match_obs:
            if self._obs_lines is None:
                raise ValueError("Cannot fit without observed lines set.")
            tree = KDTree(
                np.concatenate([c.compressed() for c in self.observed_line_locations])[:, None]
            )
            ix = tree.query(pixels[:, None])[1]
            pixels = tree.data[ix][:, 0]

        fitter = fitting.LinearLSQFitter()
        shift, model = self._create_model(degree)
        # A polynomial of degree d has d + 1 coefficients, so the fit is already
        # under-determined when the degree equals the number of lines. Keep only as many
        # free coefficients (c0..c{nlines-1}) as there are lines.
        if model.degree >= nlines:
            warnings.warn(
                "The degree of the polynomial model is higher than the number of lines. "
                "Fixing the higher-order coefficients to zero.",
                stacklevel=2,
            )
            for i in range(nlines, model.degree + 1):
                model.fixed[f"c{i}"] = True

        model = fitter(model, pixels - self.ref_pixel, wavelengths)
        # Release the coefficients again so that a later refinement can fit all of them.
        for i in range(model.degree + 1):
            model.fixed[f"c{i}"] = False
        self.solution.p2w = shift | model

        can_match = self._cat_lines is not None and self._obs_lines is not None
        if refine_fit and can_match:
            self.refine_fit(refined_fit_degree, max_match_distance=refine_max_distance)
        elif can_match:
            self.match_lines()
        return self.solution

    def fit_dispersion(
        self,
        wavelength_bounds: tuple[float, float],
        dispersion_bounds: tuple[float, float],
        higher_order_limits: Sequence[float] | None = None,
        degree: int = 3,
        popsize: int = 30,
        max_distance: float = 100,
        refine_fit: bool = True,
        refine_max_distance: float = 5.0,
        refined_fit_degree: int | None = None,
    ) -> WavelengthSolution1D:
        """Calculate a wavelength solution using all the catalog and observed lines.

        This method estimates a wavelength solution without pre-matched pixel–wavelength
        pairs, making it suitable for automated pipelines on stable, well-characterized
        spectrographs. It uses differential evolution to optimize the polynomial parameters
        that minimize the distance between the predicted wavelengths of the observed lines and
        their nearest catalog lines. The resulting solution can optionally be refined with a
        least-squares fit to automatically matched lines.

        Parameters
        ----------
        wavelength_bounds
            (min, max) bounds for the wavelength at ``ref_pixel``; used as an optimization
            constraint.
        dispersion_bounds
            (min, max) bounds for the dispersion d(wavelength)/d(pixel) at ``ref_pixel``; used
            as an optimization constraint.
        higher_order_limits
            Absolute limits for the higher-order polynomial coefficients. Each coefficient is
            constrained to [-limit, limit]. If provided, the number of limits must equal
            (polynomial degree - 1). If `None`, the limits are derived from the magnitude of
            the mean dispersion.
        degree
            The polynomial degree for the wavelength solution.
        popsize
            Population size for ``scipy.optimize.differential_evolution``. Larger values can
            improve the chance of finding the global minimum at the cost of additional time.
        max_distance
            Maximum wavelength separation used when associating observed and catalog lines in
            the optimization. Distances larger than this threshold are clipped to this value
            in the cost function to limit the impact of outliers.
        refine_fit
            If True (default), call ``refine_fit`` after global optimization to improve the
            solution using a least-squares fit on matched lines.
        refine_max_distance
            Maximum allowed separation between catalog and observed lines for them to be
            considered a match during ``refine_fit``. Ignored if ``refine_fit`` is False.
        refined_fit_degree
            The polynomial degree for the refined fit. Can be higher than ``degree``. If
            ``None``, equals to ``degree``.

        Returns
        -------
        WavelengthSolution1D
            The fitted wavelength solution.

        Raises
        ------
        ValueError
            If the number of ``higher_order_limits`` does not equal ``degree - 1``.
        """
        # Define bounds for differential_evolution.
        bounds = [np.asarray(wavelength_bounds), np.asarray(dispersion_bounds)]
        model = self._create_model(degree)
        if higher_order_limits is not None:
            if len(higher_order_limits) != model[1].degree - 1:
                raise ValueError(
                    "The number of higher-order limits must match the degree of the polynomial "
                    "model minus one."
                )
            for v in higher_order_limits:
                bounds.append(np.asarray([-v, v]))
        else:
            # Scale the default limits with the magnitude of the dispersion. The absolute
            # value keeps the limits finite for a negative dispersion, i.e. wavelength
            # decreasing with pixel.
            dispersion_scale = np.abs(np.mean(dispersion_bounds))
            for i in range(2, model[1].degree + 1):
                bounds.append(np.array([-1, 1]) * 10 ** (np.log10(dispersion_scale) - 2 * i))
        bounds = np.array(bounds)

        self._fit = optimize.differential_evolution(
            lambda x: self._line_match_distance(x, model, max_distance),
            bounds=bounds,
            popsize=popsize,
            init="sobol",
        )
        self.solution.p2w = self._create_model(degree, coeffs=self._fit.x)

        can_match = self._cat_lines is not None and self._obs_lines is not None
        if refine_fit:
            self.refine_fit(refined_fit_degree, max_match_distance=refine_max_distance)
        elif can_match:
            self.match_lines()
        return self.solution

    def refine_fit(
        self, degree: None | int = None, max_match_distance: float = 5.0, max_iter: int = 5
    ) -> WavelengthSolution1D:
        """Refine the pixel-to-wavelength transformation fit.

        Fits (or re-fits) the polynomial wavelength solution using the currently matched
        pixel–wavelength pairs. Optionally adjusts the polynomial degree, filters matches by a
        maximum separation, and iterates the match-and-fit cycle until the RMS of the
        residuals stops changing or ``max_iter`` is reached.

        Parameters
        ----------
        degree
            The polynomial degree for the wavelength solution. If ``None``, the degree
            previously set by the `~WavelengthCalibration1D.fit_lines` or
            `~WavelengthCalibration1D.fit_dispersion` method will be used.
        max_match_distance
            Maximum allowable distance used to identify matched pixel and wavelength data
            points. Points exceeding the bound will not be considered in the fit.
        max_iter
            Maximum number of fitting iterations.

        Returns
        -------
        WavelengthSolution1D
            The refined wavelength solution.
        """
        # Create a new model with the current parameters if degree is specified.
        if degree is not None and degree != self.degree:
            model = self._create_model(degree, coeffs=self.solution.p2w[1].parameters)
        else:
            model = self.solution.p2w
        shift, poly = model

        fitter = fitting.LinearLSQFitter()
        rms = np.nan
        for _ in range(max_iter):
            self.match_lines(max_match_distance)
            matched_pix = np.ma.concatenate(self.observed_line_locations).compressed()
            matched_wav = np.ma.concatenate(self.catalog_line_locations).compressed()
            rms_new = np.sqrt(((matched_wav - model(matched_pix)) ** 2).mean())
            if isclose(rms_new, rms):
                break
            model = shift | fitter(poly, matched_pix - self.ref_pixel, matched_wav)
            rms = rms_new
            # ``match_lines`` reads ``self.solution.p2w``; publish the refined model so that
            # the next iteration matches the lines with it rather than with the initial one.
            self.solution.p2w = model
        # Also covers ``max_iter == 0`` combined with a degree change.
        self.solution.p2w = model
        return self.solution

    @property
    def degree(self) -> None | int:
        """Degree of the polynomial wavelength solution, or None if not yet set."""
        return self._degree

    @degree.setter
    def degree(self, degree: int | None):
        if degree is not None and degree < 1:
            raise ValueError("Degree must be at least 1.")
        self._degree = degree

    @property
    def observed_lines(self) -> None | list[MaskedArray]:
        """Pixel positions and amplitudes of the observed lines as a list of masked arrays."""
        return self._obs_lines

    @cached_property
    def observed_line_locations(self) -> None | list[MaskedArray]:
        """Pixel positions of the observed lines as a list of masked arrays."""
        if self._obs_lines is None:
            return None
        return [line[:, 0] for line in self._obs_lines]

    @cached_property
    def observed_line_amplitudes(self) -> None | list[MaskedArray]:
        """Amplitudes of the observed lines as a list of masked arrays."""
        if self._obs_lines is None:
            return None
        return [line[:, 1] for line in self._obs_lines]

    @observed_lines.setter
    def observed_lines(self, line_lists: ArrayLike | list[ArrayLike]):
        if not isinstance(line_lists, Sequence):
            line_lists = [line_lists]
        self._obs_lines = [_format_linelist(lst) for lst in line_lists]
        # Invalidate the cached views. ``cached_property`` stores its value in the instance
        # ``__dict__``; popping avoids computing a value only to delete it, as ``hasattr``
        # would do.
        self.__dict__.pop("observed_line_locations", None)
        self.__dict__.pop("observed_line_amplitudes", None)

    @property
    def catalog_lines(self) -> None | list[MaskedArray]:
        """Catalog line wavelengths and amplitudes as a list of masked arrays."""
        return self._cat_lines

    @cached_property
    def catalog_line_locations(self) -> None | list[MaskedArray]:
        """Wavelengths of the catalog lines as a list of masked arrays."""
        if self._cat_lines is None:
            return None
        return [line[:, 0] for line in self._cat_lines]

    @cached_property
    def catalog_line_amplitudes(self) -> None | list[MaskedArray]:
        """Amplitudes of the catalog lines as a list of masked arrays."""
        # The availability check must look at the catalog lines, not the observed lines.
        if self._cat_lines is None:
            return None
        return [line[:, 1] for line in self._cat_lines]

    @catalog_lines.setter
    def catalog_lines(self, line_lists: ArrayLike | list[ArrayLike]):
        if not isinstance(line_lists, Sequence):
            line_lists = [line_lists]
        self._cat_lines = [_format_linelist(lst) for lst in line_lists]
        # Invalidate the cached views (see the ``observed_lines`` setter).
        self.__dict__.pop("catalog_line_locations", None)
        self.__dict__.pop("catalog_line_amplitudes", None)

    def match_lines(self, max_distance: float = 5) -> None:
        """Match the observed lines to theoretical lines.

        The observed line positions are mapped to wavelengths with the current solution and
        the nearest catalog line is looked up for each of them. The masks of the observed and
        catalog line lists are updated in place so that only the matched lines are unmasked.

        Parameters
        ----------
        max_distance
            The maximum allowed distance between the catalog and observed lines for them to
            be considered a match.
        """
        for iframe, tree in enumerate(self._trees):
            l, ix = tree.query(
                self.solution.p2w(self.observed_line_locations[iframe].data)[:, None],
                distance_upper_bound=max_distance,
            )
            # Lines without a neighbour within ``max_distance`` get an infinite distance.
            m = np.isfinite(l)

            # Check for observed lines that match a catalog line.
            # Remove all but the nearest match. This isn't an optimal solution,
            # we could also iterate the match by removing the currently matched
            # lines, but this works for now.
            uix, cnt = np.unique(ix[m], return_counts=True)
            if any(n := cnt > 1):
                for i, c in zip(uix[n], cnt[n]):
                    s = ix == i
                    r = np.zeros(c, dtype=bool)
                    r[np.argmin(l[s])] = True
                    m[s] = r

            self._cat_lines[iframe].mask[:, :] = True
            self._cat_lines[iframe].mask[ix[m], :] = False
            self._obs_lines[iframe].mask[:, :] = ~m[:, None]

    def remove_unmatched_lines(self) -> None:
        """Remove unmatched lines from observation and catalog line data."""
        self.observed_lines = [lst.compressed().reshape([-1, 2]) for lst in self._obs_lines]
        self.catalog_lines = [lst.compressed().reshape([-1, 2]) for lst in self._cat_lines]
        self._create_trees()

    def rms(self, space: Literal["pixel", "wavelength"] = "wavelength") -> float:
        """Compute the RMS of the residuals between matched lines in the pixel or wavelength
        space.

        Parameters
        ----------
        space
            The space in which to calculate the RMS residual. If 'wavelength', the
            calculation is performed in the wavelength space. If 'pixel', it is performed in
            the pixel space. Default is 'wavelength'.

        Returns
        -------
        float
            The root-mean-square residual of the matched lines.

        Raises
        ------
        ValueError
            If ``space`` is neither 'pixel' nor 'wavelength'.
        """
        self.match_lines()
        mpix = np.ma.concatenate(self.observed_line_locations).compressed()
        mwav = np.ma.concatenate(self.catalog_line_locations).compressed()
        if space == "wavelength":
            return np.sqrt(((mwav - self.solution.p2w(mpix)) ** 2).mean())
        elif space == "pixel":
            return np.sqrt(((mpix - self.solution.w2p(mwav)) ** 2).mean())
        else:
            raise ValueError("Space must be either 'pixel' or 'wavelength'")

    def _plot_lines(
        self,
        kind: Literal["observed", "catalog"],
        frames: int | Sequence[int] | None = None,
        axes: Axes | Sequence[Axes] | None = None,
        figsize: tuple[float, float] | None = None,
        plot_labels: bool | Sequence[bool] = True,
        map_x: bool = False,
        label_kwargs: dict | None = None,
    ) -> Figure:
        """
        Plot lines with optional features such as wavelength mapping and label customization.

        Parameters
        ----------
        kind
            Specifies the line list to plot.
        frames
            Frame indices to plot. If None, all frames are plotted.
        axes
            Axes object(s) where the lines should be plotted. If None, new Axes are generated.
        figsize
            Size of the figure to use if creating new Axes. Ignored if axes are provided.
        plot_labels
            Flag(s) indicating whether to display labels for the lines. If a single value is
            provided, it is applied to all frames.
        map_x
            If True, maps the x-axis values between pixel and wavelength space.
        label_kwargs
            Additional keyword arguments to customize the label style.

        Returns
        -------
        Figure
            The Figure object containing the plotted spectral lines.

        Raises
        ------
        ValueError
            If ``map_x`` is True but no wavelength solution has been fitted.
        """
        largs = dict(backgroundcolor="w", rotation=90, size="small")
        if label_kwargs is not None:
            largs.update(label_kwargs)

        if frames is None:
            frames = np.arange(self.nframes)
        else:
            frames = np.atleast_1d(frames)

        if axes is None:
            fig, axes = subplots(
                frames.size, 1, figsize=figsize, constrained_layout=True, sharex="all"
            )
        elif isinstance(axes, Axes):
            fig = axes.figure
            axes = [axes]
        else:
            fig = axes[0].figure
        # ``subplots`` returns a bare Axes for a single frame; always work with an array.
        axes = np.atleast_1d(axes)

        if isinstance(plot_labels, bool):
            plot_labels = np.full(frames.size, plot_labels, dtype=bool)

        if map_x and self.solution.p2w is None:
            raise ValueError("Cannot map between pixels and wavelengths without a fitted model.")

        if kind == "observed":
            transform = self.solution.pix_to_wav if map_x else lambda x: x
            linelists = self.observed_lines
            spectra = self.arc_spectra
            lc = "C0"
        else:
            transform = self.solution.wav_to_pix if map_x else lambda x: x
            linelists = self.catalog_lines
            spectra = None
            lc = "C1"

        ypad = 1.3
        labels = []
        # ``iframe`` is the position of the panel, ``frame`` the index of the data frame to
        # draw in it. The two differ whenever a subset of frames is requested.
        for iframe, (ax, frame) in enumerate(zip(axes, frames)):
            if spectra is not None:
                spc = spectra[frame]
                vmax = np.nanmax(spc.flux.value)
                ax.plot(transform(spc.spectral_axis.value), spc.flux.value / vmax, "k")
            else:
                vmax = 1.0

            if linelists is not None:
                labels.append([])
                linelist = linelists[frame]

                # Loop over individual lines in the line list.
                for i in range(linelist.shape[0]):
                    c, a = linelist.data[i]
                    # Matched (unmasked) lines are drawn solid, unmatched ones dotted.
                    ls = "-" if linelist.mask[i, 0] == 0 else ":"

                    ax.plot(transform([c, c]), [a / vmax + 0.1, 1.27], c=lc, ls=ls, zorder=-100)
                    if plot_labels[iframe]:
                        lloc = transform(c)
                        # Show four significant digits. The magnitude is undefined for zero
                        # and non-finite locations, and negative locations need the absolute
                        # value; fall back to three decimals where it is undefined.
                        if np.isfinite(lloc) and lloc != 0:
                            ndigits = 3 - int(np.floor(np.log10(abs(lloc))))
                        else:
                            ndigits = 3
                        labels[-1].append(
                            ax.text(
                                lloc,
                                ypad,
                                np.round(lloc, ndigits),
                                ha="center",
                                va="top",
                                **largs,
                            )
                        )
                        labels[-1][-1].set_clip_on(True)
                        # The amplitude as z-order lets the strongest lines win when
                        # overlapping labels are removed.
                        labels[-1][-1].zorder = a

        if (kind == "observed" and not map_x) or (kind == "catalog" and map_x):
            xlabel = "Pixel"
        else:
            xlabel = f"Wavelength [{self._unit_str}]"

        if kind == "catalog":
            axes[0].xaxis.set_label_position("top")
            axes[0].xaxis.tick_top()
            setp(axes[0], xlabel=xlabel)
            for ax in axes[1:]:
                ax.set_xticklabels([])
        else:
            setp(axes[-1], xlabel=xlabel)
            for ax in axes[:-1]:
                ax.set_xticklabels([])

        xlims = np.array([ax.get_xlim() for ax in axes])
        setp(axes, xlim=(xlims[:, 0].min(), xlims[:, 1].max()), yticks=[])

        if linelists is not None:
            # The text extents are only known after the figure has been drawn.
            fig.canvas.draw()
            for i in range(len(frames)):
                # A frame without any labels has no extent to fit the y-axis to.
                if plot_labels[i] and labels[i]:
                    # Calculate the label bounding box upper limits and adjust the y-axis
                    # limits.
                    tr_to_data = axes[i].transData.inverted()
                    ymax = -np.inf
                    for lb in labels[i]:
                        ymax = max(ymax, tr_to_data.transform(lb.get_window_extent().p1)[1])
                    setp(axes[i], ylim=(-0.04, ymax * 1.06))

                    # Remove the overlapping labels prioritizing the high-amplitude lines.
                    _unclutter_text_boxes(labels[i])
        return fig

    def plot_catalog_lines(
        self,
        frames: int | Sequence[int] | None = None,
        axes: Axes | Sequence[Axes] | None = None,
        figsize: tuple[float, float] | None = None,
        plot_labels: bool | Sequence[bool] = True,
        map_to_pix: bool = False,
        label_kwargs: dict | None = None,
    ) -> Figure:
        """Plot the catalog lines.

        Parameters
        ----------
        frames
            Specifies the frames to be plotted. If an integer, only one frame is plotted. If
            a sequence, the specified frames are plotted. If None, all frames are plotted.
        axes
            The matplotlib axes where catalog data will be plotted. If provided, the function
            will plot on these axes. If None, new axes will be created.
        figsize
            Specifies the dimensions of the figure as (width, height). If None, the default
            dimensions are used.
        plot_labels
            If True, the numerical values associated with the catalog data will be displayed
            in the plot. If False, only the graphical representation of the lines will be
            shown.
        map_to_pix
            Indicates whether the catalog data should be mapped to pixel coordinates before
            plotting. If True, the data is converted to pixel coordinates.
        label_kwargs
            Specifies the keyword arguments for the line label text objects.

        Returns
        -------
        Figure
            The matplotlib figure containing the plotted catalog lines.
        """
        return self._plot_lines(
            "catalog",
            frames=frames,
            axes=axes,
            figsize=figsize,
            plot_labels=plot_labels,
            map_x=map_to_pix,
            label_kwargs=label_kwargs,
        )

    def plot_observed_lines(
        self,
        frames: int | Sequence[int] | None = None,
        axes: Axes | Sequence[Axes] | None = None,
        figsize: tuple[float, float] | None = None,
        plot_labels: bool | Sequence[bool] = True,
        map_to_wav: bool = False,
        label_kwargs: dict | None = None,
    ) -> Figure:
        """Plot observed spectral lines for the given arc spectra.

        Parameters
        ----------
        frames
            Specifies the frame(s) for which the plot is to be generated. If None, all frames
            are plotted. When an integer is provided, a single frame is used. For a sequence
            of integers, multiple frames are plotted.
        axes
            Axes object(s) to plot the spectral lines on. If None, new axes are created.
        figsize
            Dimensions of the figure to be created, specified as a tuple (width, height).
            Ignored if ``axes`` is provided.
        plot_labels
            If True, plots the numerical values of the observed lines at their respective
            locations on the graph.
        map_to_wav
            Determines whether to map the x-axis values to wavelengths.
        label_kwargs
            Specifies the keyword arguments for the line label text objects.

        Returns
        -------
        Figure
            The matplotlib figure containing the observed lines plot.
        """
        fig = self._plot_lines(
            "observed",
            frames=frames,
            axes=axes,
            figsize=figsize,
            plot_labels=plot_labels,
            map_x=map_to_wav,
            label_kwargs=label_kwargs,
        )
        for ax in fig.axes:
            ax.autoscale(True, "x", tight=True)
        return fig

    def plot_fit(
        self,
        frames: Sequence[int] | int | None = None,
        figsize: tuple[float, float] | None = None,
        plot_labels: bool | Sequence[bool] = True,
        obs_to_wav: bool = False,
        cat_to_pix: bool = False,
        label_kwargs: dict | None = None,
    ) -> Figure:
        """Plot the fitted catalog and observed lines for the specified arc spectra.

        Parameters
        ----------
        frames
            The indices of the frames to plot. If `None`, all frames from 0 to
            ``self.nframes - 1`` are plotted.
        figsize
            Defines the width and height of the figure in inches. If `None`, the default size
            is used.
        plot_labels
            If `True`, print line locations over the plotted lines. Can also be a list with
            the same length as ``frames``.
        obs_to_wav
            If `True`, transform the x-axis of observed lines to the wavelength domain using
            the fitted wavelength solution.
        cat_to_pix
            If `True`, transforms catalog data points to pixel values before plotting.
        label_kwargs
            Specifies the keyword arguments for the line label text objects.

        Returns
        -------
        matplotlib.figure.Figure
            The figure object containing the generated subplots.
        """
        if frames is None:
            frames = np.arange(self.nframes)
        else:
            frames = np.atleast_1d(frames)

        # Each frame gets a catalog panel (even rows) above an observed panel (odd rows).
        fig, axs = subplots(2 * frames.size, 1, constrained_layout=True, figsize=figsize)
        self.plot_catalog_lines(
            frames,
            axs[0::2],
            plot_labels=plot_labels,
            map_to_pix=cat_to_pix,
            label_kwargs=label_kwargs,
        )
        self.plot_observed_lines(
            frames,
            axs[1::2],
            plot_labels=plot_labels,
            map_to_wav=obs_to_wav,
            label_kwargs=label_kwargs,
        )

        xlims = np.array([ax.get_xlim() for ax in axs[::2]])
        if obs_to_wav:
            setp(axs, xlim=(xlims[:, 0].min(), xlims[:, 1].max()))
        else:
            setp(axs[::2], xlim=(xlims[:, 0].min(), xlims[:, 1].max()))

        setp(axs[0], yticks=[], xlabel=f"Wavelength [{self._unit_str}]")
        for ax in axs[1:-1]:
            ax.set_xlabel("")
            ax.set_xticklabels("")

        axs[0].xaxis.set_label_position("top")
        axs[0].xaxis.tick_top()
        return fig

    def plot_residuals(
        self,
        ax: Axes | None = None,
        space: Literal["pixel", "wavelength"] = "wavelength",
        figsize: tuple[float, float] | None = None,
    ) -> Figure:
        """Plot the residuals of pixel-to-wavelength or wavelength-to-pixel transformation.

        Parameters
        ----------
        ax
            Matplotlib Axes object to plot on. If None, a new figure and axes are created.
        space
            The reference space used for plotting residuals. Options are 'pixel' for
            residuals in pixel space or 'wavelength' for residuals in wavelength space.
        figsize
            The size of the figure in inches, if a new figure is created.

        Returns
        -------
        matplotlib.figure.Figure
            The figure containing the residual plot.

        Raises
        ------
        ValueError
            If ``space`` is neither 'pixel' nor 'wavelength'.
        """
        if ax is None:
            fig, ax = subplots(figsize=figsize, constrained_layout=True)
        else:
            fig = ax.figure

        self.match_lines()
        mpix = np.ma.concatenate(self.observed_line_locations).compressed()
        mwav = np.ma.concatenate(self.catalog_line_locations).compressed()

        if space == "wavelength":
            twav = self.solution.pix_to_wav(mpix)
            ax.plot(mwav, mwav - twav, ".")
            ax.text(
                0.98,
                0.95,
                f"RMS = {np.sqrt(((mwav - twav) ** 2).mean()):4.2f} {self._unit_str}",
                transform=ax.transAxes,
                ha="right",
                va="top",
            )
            setp(
                ax,
                xlabel=f"Wavelength [{self._unit_str}]",
                ylabel=f"Residuals [{self._unit_str}]",
            )
        elif space == "pixel":
            tpix = self.solution.wav_to_pix(mwav)
            ax.plot(mpix, mpix - tpix, ".")
            ax.text(
                0.98,
                0.95,
                f"RMS = {np.sqrt(((mpix - tpix) ** 2).mean()):4.2f} pix",
                transform=ax.transAxes,
                ha="right",
                va="top",
            )
            setp(ax, xlabel="Pixel", ylabel="Residuals [pix]")
        else:
            raise ValueError("Invalid space specified for plotting residuals.")
        ax.axhline(0, c="k", lw=1, ls="--")
        return fig