import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# Synthetic 2-D spectrum: every dispersion column holds a Gaussian
# cross-dispersion profile centred on the trace position ``tr``.

fw = 10  # figure width in inches (the plot height follows the image aspect)
nd, ncd = 31, 13  # pixel counts along the dispersion / cross-dispersion axes
xd, xcd = np.arange(nd), np.arange(ncd)  # pixel coordinates along each axis

# Trace centre (cross-dispersion pixel) for each dispersion column:
# a piecewise-constant path 4 -> 6 -> 8 -> 6.
tr = np.full_like(xd, 6)
tr[:6] = 4
tr[15:23] = 8

sigma = 1.0  # Gaussian width of the cross-dispersion profile, in pixels

# Evaluate all columns at once: xcd[:, np.newaxis] has shape (ncd, 1) and
# broadcasts against tr (nd,), so spectrum[j, i] = N(xcd[j]; tr[i], sigma).
# Rows are cross-dispersion pixels, columns are dispersion pixels, matching
# the (row, column) layout imshow expects.
spectrum = norm.pdf(xcd[:, np.newaxis], loc=tr, scale=sigma)

# Display the synthetic spectrum with the trace centre overlaid.
plt.rc('font', size=13)
fig, ax = plt.subplots(figsize=(fw, fw * (ncd / nd)), constrained_layout=True)
ax.imshow(spectrum, origin='lower')  # row 0 drawn at the bottom
ax.plot(xd, tr, 'k')  # trace centre for each dispersion column
# imshow centres pixel k on coordinate k, so minor ticks at k + 0.5 sit on
# pixel boundaries and the faint minor grid outlines individual pixels.
ax.set_xticks(xd + 0.5, minor=True)
ax.set_yticks(xcd + 0.5, minor=True)
ax.grid(alpha=0.25, lw=1, which='minor')
ax.set(xlabel='Dispersion axis', ylabel='Cross-dispersion axis')
# plt.show(), unlike fig.show(), starts the GUI event loop and blocks when run
# as a plain script, so the window does not vanish as soon as the script ends.
plt.show()