import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# Use a 13 pt default font size for all text drawn by matplotlib in this script.
plt.rcParams['font.size'] = 13

# Figure width in inches (the plotting code derives the height from it).
fw = 10
# Image size in pixels: nd columns along the dispersion axis, ncd rows along
# the cross-dispersion axis (spectrum has shape (ncd, nd)).
nd, ncd = 31, 13
# Integer pixel coordinates along each axis.
xd, xcd = np.arange(nd), np.arange(ncd)

# Trace centre (cross-dispersion position) as a quadratic in dispersion pixel:
# tr(x) = -0.01*x**2 + 0.2*x + 7.0  (poly1d takes highest power first).
tr = np.poly1d([-0.01, 0.2, 7.0])
# Standard deviation, in pixels, of the Gaussian cross-dispersion profile.
sigma = 1.0
# Column i is a Gaussian centred on tr(xd[i]), sampled at the row coordinates.
# Broadcasting the (ncd, 1) row coordinates against the (1, nd) trace centres
# builds the whole (ncd, nd) image in a single vectorized call.
spectrum = norm.pdf(xcd[:, np.newaxis], loc=tr(xd)[np.newaxis, :], scale=sigma)

# Display the synthetic spectrum with its trace centre overplotted; the figure
# height scales with ncd/nd so it roughly follows the image's aspect ratio.
fig, ax = plt.subplots(figsize=(fw, fw * (ncd / nd)), constrained_layout=True)
# origin='lower' puts row 0 at the bottom, so the image's y axis increases
# upward like the trace coordinates plotted on top of it.
ax.imshow(spectrum, origin='lower')
ax.plot(xd, tr(xd), 'k')
# imshow centres pixel k on coordinate k by default, so pixel edges lie at
# k + 0.5; minor ticks there let the grid outline individual pixels.
ax.set_xticks(xd + 0.5, minor=True)
ax.set_yticks(xcd + 0.5, minor=True)
ax.grid(alpha=0.25, lw=1, which='minor')
ax.set(xlabel='Dispersion axis', ylabel='Cross-dispersion axis')
# Unlike Figure.show(), pyplot.show() runs the GUI event loop when needed, so
# the window stays open when this file is executed as a script (and it does not
# warn on non-GUI backends such as Jupyter's inline backend).
plt.show()