# Change of variables with autodiff

In this notebook we predict the distribution of $$q=1/x$$ when $$x\sim N(\mu,\sigma)$$, using automatic differentiation (autodiff). This is a follow-up to the previous notebook, How do distributions transform under a change of variables?, which did not use autodiff.

by Kyle Cranmer, March 2, 2020

```python
import numpy as np
import scipy.stats as scs
import matplotlib.pyplot as plt

mean = 1.
std = .3
N = 10000

# draw N samples of x ~ N(mean, std)
x = np.random.normal(mean, std, N)

# grid of x values used for plotting curves
x_plot = np.linspace(0.1, 3, 100)

_ = plt.hist(x, bins=50, density=True)
plt.plot(x_plot, scs.norm.pdf(x_plot, mean, std), c='r', lw=2)
```


```python
def q(x):
    return 1/x

q_ = q(x)            # samples of q
q_plot = q(x_plot)   # q evaluated on the plotting grid

plt.plot(x_plot, q_plot, label='q(x)')
_ = plt.hist(x, bins=50, density=True, label='p(x)')
plt.xlabel('x')
plt.ylabel('q(x) or p(x)')
plt.legend()
```


```python
mybins = np.linspace(0, 3, 50)
_ = plt.hist(q_, bins=mybins, density=True)
plt.xlabel('q')
plt.ylabel('p(q)')
```



## Do it by hand

We want to evaluate $$p_q(q) = \frac{p_x(x(q))}{ | dq/dx |}$$, which requires knowing the derivative and how to invert from $$q \to x$$. The inversion is easy: $$x(q)=1/q$$. The derivative is $$dq/dx = -1/x^2$$, so in terms of $$q$$ its magnitude is $$|dq/dx| = q^2$$, which gives $$p_q(q) = p_x(1/q)/q^2$$.
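As a quick sanity check of the hand-derived formula, here is a minimal sketch (using SciPy's `quad` for numerical integration, which is our choice and not part of the original notebook). For $$0 < a < b$$, the map $$q = 1/x$$ sends $$x \in (1/b, 1/a)$$ onto $$q \in (a, b)$$, so integrating the predicted $$p_q$$ over $$(a, b)$$ should reproduce $$P(1/b < x < 1/a)$$ computed directly from the normal CDF. The function name `p_q` and the interval endpoints are illustrative.

```python
import numpy as np
import scipy.stats as scs
from scipy.integrate import quad

mean, std = 1.0, 0.3

def p_q(q):
    """Predicted density of q = 1/x for x ~ N(mean, std): p_x(1/q) / |dq/dx|, with |dq/dx| = q**2."""
    return scs.norm.pdf(1.0 / q, mean, std) / q**2

# q = 1/x maps x in (1/b, 1/a) onto q in (a, b) for 0 < a < b
a, b = 0.5, 4.0
mass_q, _ = quad(p_q, a, b)  # probability from the predicted density of q
mass_x = scs.norm.cdf(1 / a, mean, std) - scs.norm.cdf(1 / b, mean, std)  # same probability computed in x

print(mass_q, mass_x)  # the two numbers should agree to many decimal places
```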

```python
_ = plt.hist(q_, bins=mybins, density=True, label='histogram')
# p_q(q) = p_x(1/q) / |dq/dx|, with |dq/dx| = q^2
plt.plot(q_plot, scs.norm.pdf(1/q_plot, mean, std)/q_plot/q_plot, c='r', lw=2, label='prediction')
plt.xlim((0, 3))
plt.xlabel('q')
plt.ylabel('p(q)')
plt.legend()
```
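The agreement in the plot can also be checked numerically. This is a minimal sketch under our own assumptions (a fixed seed and a larger sample of 200,000, neither from the notebook): histogram the samples of $$q$$ and compare each bin with the predicted density at the bin center. One detail: with `density=True`, NumPy and Matplotlib normalize over the samples that fall inside the bins, and here about 1% of the $$q$$ samples land outside $$[0, 3]$$, so this sketch divides the raw counts by the total number of samples instead.

```python
import numpy as np
import scipy.stats as scs

mean, std = 1.0, 0.3
N = 200_000
rng = np.random.default_rng(0)            # fixed seed so the check is reproducible
q_samples = 1 / rng.normal(mean, std, N)  # samples of q = 1/x

bins = np.linspace(0, 3, 50)
counts, edges = np.histogram(q_samples, bins=bins)
widths = np.diff(edges)
# normalize by all N samples, not only those that fall inside [0, 3]
empirical = counts / (N * widths)

centers = 0.5 * (edges[:-1] + edges[1:])
predicted = scs.norm.pdf(1 / centers, mean, std) / centers**2  # p_x(x(q)) / |dq/dx|

print(np.max(np.abs(empirical - predicted)))  # should be small relative to the peak density
```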



Alternatively, we don’t need to know the inverse $$x(q)$$ in closed form. Instead, we can start with x_plot and use the evaluated pairs (x_plot, q_plot = q(x_plot)): at each grid point, the $$x$$ that produces q_plot[i] is simply x_plot[i].

Here is a plot of the inverse made that way.

```python
plt.plot(q_plot, x_plot, c='r', lw=2, label='inverse x(q)')
plt.xlim((0, 3))
plt.xlabel('q')
plt.ylabel('x(q)')
plt.legend()
```



And here is a plot of our prediction using x_plot directly, with $$|dq/dx| = x^{-2}$$:

```python
_ = plt.hist(q_, bins=mybins, density=True, label='histogram')
# dividing by |dq/dx| = x^-2 is the same as multiplying p_x(x) by x^2
plt.plot(q_plot, scs.norm.pdf(x_plot, mean, std)/np.power(x_plot, -2), c='r', lw=2, label='prediction')
plt.xlim((0, 3))
plt.xlabel('q')
plt.ylabel('p(q)')
plt.legend()
```
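The two routes give the same curve, because x_plot[i] is exactly the $$x$$ with $$q(x) =$$ q_plot[i]. A minimal sketch checks this numerically. It also includes an optional extension that is our own assumption rather than part of the notebook: if $$x(q)$$ is needed at $$q$$ values other than q_plot, the pairs can be interpolated with `np.interp`. Since $$q = 1/x$$ is decreasing, the pairs are reversed so that the sample points passed to `np.interp` increase, as it requires. The grid `q_new` is illustrative.

```python
import numpy as np
import scipy.stats as scs

mean, std = 1.0, 0.3
x_plot = np.linspace(0.1, 3, 100)
q_plot = 1 / x_plot  # evaluated pairs (x_plot, q_plot)

# Route 1: analytic inverse x(q) = 1/q and |dq/dx| = q**2
pred_inverse = scs.norm.pdf(1 / q_plot, mean, std) / q_plot**2
# Route 2: reuse x_plot as x(q) and |dq/dx| = x**-2
pred_pairs = scs.norm.pdf(x_plot, mean, std) / np.power(x_plot, -2)

# Optional extension: linearly interpolate x(q) from the pairs at new q values.
# np.interp needs increasing sample points, so reverse the decreasing q_plot.
q_new = np.linspace(0.5, 2.0, 7)
x_interp = np.interp(q_new, q_plot[::-1], x_plot[::-1])

print(np.max(np.abs(pred_inverse - pred_pairs)))  # floating-point round-off only
print(np.max(np.abs(x_interp - 1 / q_new)))       # linear-interpolation error
```

The interpolation error grows where the grid of q_plot values is sparse (large $$q$$, small $$x$$), so a denser x_plot is needed there if accuracy matters.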



## With Jax Autodiff for the derivatives

Now let’s do the same thing using Jax to calculate the derivatives. We make a new function dq by applying Jax's grad function to our own function q (e.g., dq = grad(q)).

```python
from jax import grad, vmap
import jax.numpy as np  # note: from here on, np refers to jax.numpy

# define the gradient with grad(q)
dq = grad(q)   # dq is a new python function
print(dq(.5))  # should be -4
```

-4.0


```python
# dq(x) on the whole array fails: grad only works for functions with a scalar output.

# define the gradient with vmap(grad(q)) so that it works elementwise on arrays
dq = vmap(grad(q))

# print dq/dx for x = 0.5, 1, 2
# it should be -1/x^2 = -4, -1, -0.25
dq(np.array([.5, 1, 2.]))
```

DeviceArray([-4.  , -1.  , -0.25], dtype=float32)
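Why the broadcasting version is needed: `jax.grad` only differentiates functions that return a scalar, so `grad(q)` applied to an array raises a `TypeError`. A minimal sketch, assuming Jax is installed, compares `vmap(grad(q))` with the analytic derivative. It also shows an alternative that is our own suggestion, not the notebook's: because `q` acts elementwise, the gradient of `jnp.sum(q(x))` with respect to the array `x` has, as its $$i$$-th entry, $$dq/dx$$ evaluated at $$x_i$$.

```python
import jax.numpy as jnp
from jax import grad, vmap

def q(x):
    return 1 / x

xs = jnp.array([0.5, 1.0, 2.0])

# grad(q) on an array fails because grad requires a scalar-valued function
try:
    grad(q)(xs)
    grad_on_array_failed = False
except TypeError:
    grad_on_array_failed = True

dq_vmap = vmap(grad(q))                 # map the scalar derivative over the array
dq_sum = grad(lambda x: jnp.sum(q(x)))  # valid only because q acts elementwise

print(grad_on_array_failed)
print(dq_vmap(xs))  # compare with -1/x**2
print(dq_sum(xs))
```

`vmap(grad(q))` works for any scalar function, while the sum trick relies on each output depending only on its own input.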

```python
# plot gradient
plt.plot(x_plot, -np.power(x_plot, -2), c='black', lw=2, label='-1/x^2')
plt.plot(x_plot, dq(x_plot), c='r', lw=2, ls='dashed', label='dq/dx from jax')
plt.xlabel('x')
plt.ylabel('dq/dx')
plt.legend()
```



As before, we want to evaluate $$p_q(q) = \frac{p_x(x(q))}{ | dq/dx |}$$, which requires knowing how to invert from $$q \to x$$. That’s easy: $$x(q)=1/q$$. But we also have the evaluated pairs (x_plot, q_plot), so we can just use x_plot whenever we need $$x(q)$$.
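Written out on the grid, with $$x_i$$ the entries of x_plot and $$q_i = q(x_i)$$ the entries of q_plot, the prediction only evaluates $$p_x$$ and the derivative at the points $$x_i$$, so no explicit inverse appears:

$$
p_q(q_i) = \frac{p_x(x_i)}{\left| \frac{dq}{dx}(x_i) \right|}, \qquad q_i = \frac{1}{x_i}, \qquad \left| \frac{dq}{dx}(x_i) \right| = \frac{1}{x_i^2} \quad\Longrightarrow\quad p_q(q_i) = x_i^2 \, p_x(x_i).
$$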

Putting it all together, we can again either invert $$x(q)$$ by hand and use Jax for the derivative:

```python
_ = plt.hist(q_, bins=np.linspace(-1, 3, 50), density=True, label='histogram')
# evaluate the Jax derivative at x(q) = 1/q
plt.plot(q_plot, scs.norm.pdf(1/q_plot, mean, std)/np.abs(dq(1/q_plot)), c='r', lw=2, label='prediction')
plt.xlim((0, 3))
plt.xlabel('q')
plt.ylabel('p(q)')
plt.legend()
```



or we can use the pairs (x_plot, q_plot):

```python
_ = plt.hist(q_, bins=np.linspace(-1, 3, 50), density=True, label='histogram')
# evaluate the Jax derivative directly on the x_plot grid
plt.plot(q_plot, scs.norm.pdf(x_plot, mean, std)/np.abs(dq(x_plot)), c='r', lw=2, label='prediction')
plt.xlim((0, 3))
plt.xlabel('q')
plt.ylabel('p(q)')
plt.legend()
```
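The same recipe works for other smooth, monotonic $$q(x)$$ without deriving the derivative by hand. A minimal sketch, assuming Jax is installed; the helper name `predicted_density` is hypothetical and belongs neither to the notebook nor to Jax. It evaluates $$q$$ on a grid of $$x$$ values, gets $$dq/dx$$ from `vmap(grad(q_fn))`, and returns the pairs $$(q_i, \; p_x(x_i)/|dq/dx(x_i)|)$$. For a non-monotonic $$q$$, contributions from every branch of the inverse would have to be summed, which this sketch does not do.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import norm

def predicted_density(q_fn, px_fn, x_grid):
    """Hypothetical helper: density of q = q_fn(x) at the points q_fn(x_grid).

    Assumes q_fn is scalar-in/scalar-out, differentiable, and monotonic on x_grid.
    """
    dq_dx = vmap(grad(q_fn))(x_grid)  # dq/dx at each grid point via autodiff
    return q_fn(x_grid), px_fn(x_grid) / jnp.abs(dq_dx)

mean, std = 1.0, 0.3
x_grid = jnp.linspace(0.1, 3.0, 100)

q_vals, p_vals = predicted_density(lambda x: 1 / x,
                                   lambda x: norm.pdf(x, mean, std),
                                   x_grid)
```

With Matplotlib imported, `plt.plot(q_vals, p_vals)` would overlay this prediction on the histogram in the same way as the cells above.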
