import numpy as np
import matplotlib.pyplot as plt
from astropy.wcs import WCS
import astropy.units as u
from specreduce.utils.synth_data import make_2d_arc_image

# FITS-style WCS keywords for a 2D spectral image whose dispersion axis
# (axis 1) uses the non-linear grating ("-GRA") algorithm in air wavelengths.
# Axis 2 is a plain pixel axis along the spatial direction.
non_linear_header = dict(
    # --- axis 1: dispersion ---
    CTYPE1='AWAV-GRA',   # air wavelength, grating dispersion function
    CUNIT1='Angstrom',   # units of the dispersion axis
    CRPIX1=719.8,        # reference pixel [pix]
    CRVAL1=7245.2,       # wavelength at the reference pixel [Angstrom]
    CDELT1=2.956,        # linear dispersion [Angstrom/pix]
    PV1_0=4.5e5,         # grating density [1/m]
    PV1_1=1,             # diffraction order
    PV1_2=27.0,          # incident angle [deg]
    PV1_3=1.765,         # reference refraction
    PV1_4=-1.077e6,      # refraction derivative [1/m]
    # --- axis 2: spatial ---
    CTYPE2='PIXEL',      # detector pixel coordinates
    CUNIT2='pix',        # units of the spatial axis
    CRPIX2=1,            # reference pixel
    CRVAL2=0,            # value at the reference pixel
    CDELT2=1,            # spatial units per pixel
)

# Linear counterpart of the grating header: same reference pixel, reference
# wavelength and dispersion, but a plain 'AWAV' axis type with no dispersion
# algorithm code and therefore no PV1_* grating parameters.
linear_header = dict(
    # --- axis 1: dispersion ---
    CTYPE1='AWAV',       # air wavelength, linear dispersion
    CUNIT1='Angstrom',   # units of the dispersion axis
    CRPIX1=719.8,        # reference pixel [pix]
    CRVAL1=7245.2,       # wavelength at the reference pixel [Angstrom]
    CDELT1=2.956,        # linear dispersion [Angstrom/pix]
    # --- axis 2: spatial ---
    CTYPE2='PIXEL',      # detector pixel coordinates
    CUNIT2='pix',        # units of the spatial axis
    CRPIX2=1,            # reference pixel
    CRVAL2=0,            # value at the reference pixel
    CDELT2=1,            # spatial units per pixel
)

# Build one WCS per header dictionary: grating (non-linear) solution first,
# then the linear one.
non_linear_wcs, linear_wcs = map(WCS, (non_linear_header, linear_header))

# Re-create Paper III, Figure 5: the wavelength difference between the
# non-linear (grating) and linear solutions as a function of pixel.
pix_array = np.arange(200, 1600)  # pixels 200..1599 inclusive

# Evaluate the spectral axis of each WCS on the same pixel grid.
nlin, lin = (
    solution.spectral.pixel_to_world(pix_array)
    for solution in (non_linear_wcs, linear_wcs)
)
resid = (nlin - lin).to(u.Angstrom)

# Draw on the current axes, exactly as the pyplot wrappers would.
axes = plt.gca()
axes.plot(pix_array, resid)
axes.set_xlabel("Pixel")
axes.set_ylabel("Correction (Angstrom)")
plt.show()

def _simulate_arc(wcs):
    """Simulate a 600x512 HeI+NeI arc image (air wavelengths) using ``wcs``."""
    return make_2d_arc_image(
        nx=600,
        ny=512,
        linelists=['HeI', 'NeI'],
        line_fwhm=3,
        wave_air=True,
        wcs=wcs
    )


# Same simulated arc lamp, once per wavelength solution; only the WCS differs.
nlin_im = _simulate_arc(non_linear_wcs)
lin_im = _simulate_arc(linear_wcs)

# Subtracting the linear simulation from the non-linear one shows how the
# positions of lines diverge between the two cases.
#
# plt.imshow() draws into the *current* axes. Under non-interactive backends
# (e.g. Agg) or in interactive mode, plt.show() returns without closing the
# previous figure, so without an explicit new figure the image would be drawn
# on top of the residual line plot and inherit its axis labels.
plt.figure()
plt.imshow(nlin_im.data - lin_im.data)
plt.show()