Tutorial on Automatic Differentiation¶
(by Lukas Heinrich. See: pyhep2020-autodiff-tutorial )
Introduction¶
Welcome to this tutorial on automatic differentiation. Automatic Differentiation is a method to compute exact derivatives of functions implements as programs. It’s a widely applicable method and famously is used in many Machine learning optimization problems. E.g. neural networks, which are parametrized by weights \(\text{NN}(\text{weights})\) are trained by (stocastic) gradient descent to find the minimum of the loss function \(L\) where
This means that efficient algorithms to compute derivatives are crucial.
Aside from ML, many other use-cases require gradients: standard statistical analysis in HEP (fitting, hypothesis testing, …) requires gradients. Uncertainty propagation (e.g. track parameters) uses gradients, etc..
import pyhf
pyhf.set_backend('jax')
import jax
import jaxlib
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
Other approaches to differentiation¶
Before diving into automatic differentiation, let’s review how my might otherwise compute derivatives
Finite Differences¶
A common appraoch to approximate gradients of a black-box function is to evaluate it at close-by points \(x\) and \(x+Δx\) and
\(\frac{\partial f}{\partial x} \approx \frac{f(x) - f(x+\Delta x}{\Delta x}\) if \(\Delta x\) is sufficiently small
def black_box_func(x):
return x**3+30
def true_gradient_func(x):
return 3*x**2
def plot_gradients(nsteps,title):
xi = np.linspace(-5,5,nsteps)
yi = black_box_func(xi)
approx_gradient = np.gradient(yi,xi)
true_gradient = true_gradient_func(xi)
plt.plot(xi,yi, label = 'black-box func')
plt.scatter(xi,yi)
plt.plot(xi,approx_gradient, label = 'finite diff grad')
plt.scatter(xi,approx_gradient)
plt.plot(xi,true_gradient, label = 'true grad')
plt.scatter(xi,true_gradient)
plt.legend()
plt.title(title)
plt.show()
plot_gradients(11, title = 'it is pretty bad if Δx is too large')
plot_gradients(41, title = 'it gets better at the cost of many evaluations')
while only approximate, finite differences is *simple*. I don't need to know anything about the function beyond having the ability to *evaluate* it
This way I can compute gradients of functions encoded as a computer program, and it works in any programming language
For multivariate (possibly vector-valued) functions $\vec{f}(\vec{x}) = f_i(x_1,x_2,\dots,x_n)$ one needs to compute a finite difference gradient for each partial derivative $\frac{\partial f}{\partial x}$ in order to get the full jacobian / total derivative $df_i = J_{ik} dx_k\; J_{ik} = \frac{\partial f_i}{\partial x_k}$
In high dimensions, the number of required evaluations explodes!
Finite Differences:
Pro: easy to to, works in any language, no “framework needed”
Con: inaccurate unless one does a lot of evaluations
Con does not scale to large dimensions
Symbolic Differentiation in a CAS¶
Computer Algebra Systems (CAS), such as Mathematica (or sympy) can manipulate functional expressions and know about differentiation rules (and many other things)
If the function / the prograrm which we want to derive is available as such an expression the symbolic differentiation can produce exact gradients
import sympy
def function(x):
return x**3
def true_deriv(x):
return 3*x**2
symbolic_x = sympy.symbols('x')
symbolic_func = function(symbolic_x)
symbolic_func
Using lambdify
we can turn it into a normal python function we can evaluate
xi = np.linspace(-5,5,11)
yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)
plt.plot(xi,yi)
plt.scatter(xi,yi)
<matplotlib.collections.PathCollection at 0x7f96f85724c0>
symbolic_func
is now an experssion which we can differentiate symbolically
symbolic_deriv = symbolic_func.diff(symbolic_x)
symbolic_deriv
def plot_symbolic(nsteps,title):
xi = np.linspace(-5,5,nsteps)
yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)
plt.scatter(xi,yi)
plt.plot(xi,yi, label = 'function')
yi = true_deriv(xi)
plt.plot(xi,yi)
plt.scatter(xi,yi, label = 'true deriv')
yi = sympy.lambdify(symbolic_x,symbolic_deriv)(xi)
plt.plot(xi,yi)
plt.scatter(xi,yi, label = 'symbolic deriv')
plt.legend()
plt.title(title)
plt.show()
plot_symbolic(11,title = 'the symbolid derivative is always exact')
plot_symbolic(4, title = 'it does not matter where/how often you evaluate it')
Chain Rule in CAS¶
We can even handle function compositions
def f1(x):
#standard operations are overloaded
return x**2
def f2(x):
#note here we use a special cos function from sympy
#instead of e.g. np.cos or math.cos
return sympy.cos(x)
composition = f2(f1(symbolic_x))
composition
composition.diff(symbolic_x)
Since sympy
knows about the chain rule it can differentiate accordingly
Problems with Symbolic Differentiation¶
This looks great! We get exact derivatives. However, there are drawbacks
Need to implement it in CAS
Most functions we are interested in are not implemented e.g. Mathematica. Rather we have loads of C, C++, Python code that we are interested in.
But ok, sympy
alleviates this to some degree. The functions
f1
and f2
are fairly generic since they use operator
overloading. So a symbolic program and a “normal” program
could only differ by a few import statements
from sympy import cos
def f1(x):
return x**2
def f2(x):
return cos(x)
versus:
from math import cos
def f1(x):
return x**2
def f2(x):
return cos(x)
Note the code is almost exactly the same
But not all our functions are so simple!
Expression swell
Let’s look at a quadratic map which is applied a few times
def quadmap(x):
return x**2 + 3*x + 4
def func(x):
for i in range(6):
x = quadmap(x)
return x
quad_6_times = func(symbolic_x)
quad_6_times
This looks pretty intimidating. What happened? Symbolic programs run through the prgram and accumulate the full program into a single expression
If we would just blindly differentiate this it would look like this
quad_6_times.diff(symbolic_x)
This looks even worse!
Also note that that if we just blindly substitute x for some value e.g. x=2, we would be computing a lot of the same terms manyt times. E.g. in the above expression \(x^2+3x+4\) appears in a lot of places due to the “structure’ of the original progrm
If you knew the structure of the program you likely could precompute some of these repeating terms. However once it got all expanded all this knowledge about the structure is gone!
Modern CAS can recover some of this by finding “common subexpressions” (CSE)
sympy.cse(quad_6_times)
([(x0, x**2),
(x1, (3*x + x0 + 4)**2),
(x2, (9*x + 3*x0 + x1 + 16)**2),
(x3, (27*x + 9*x0 + 3*x1 + x2 + 52)**2),
(x4, (81*x + 27*x0 + 9*x1 + 3*x2 + x3 + 160)**2)],
[729*x + 243*x0 + 81*x1 + 27*x2 + 9*x3 + 3*x4 + (243*x + 81*x0 + 27*x1 + 9*x2 + 3*x3 + x4 + 484)**2 + 1456])
But it’s not as automatic and may note find all relevant subexpressions. In any case it’s trying hard to recover some of the structure that is already implicitly present in the prograam we want to differentiate
Control Flow
In addition to looping constucts like above, a lot of the functions we are interested in have control flow structures like if/else statements, while loops, etc..
If we try to create a symbolic expression with conditionals we fail badly
def func(x):
if x > 2:
return x**2
else:
return x**3
try:
symbolic_result = func(symbolic_x)
except TypeError as err:
print(err)
cannot determine truth value of Relational
That’s too bad because this is a perfectly respectable function almost everywhere
xi = np.linspace(-2,5,1001)
yi = np.asarray([func(xx) for xx in xi])
plt.plot(xi,yi)
plt.scatter(xi,yi)
plt.title("pretty smooth except at x=2")
plt.show()
If we could afford finite diffences it would compute gradients just fine.
g = np.gradient(yi,xi)
plt.plot(xi,g)
plt.scatter(xi,g)
plt.ylim(-2,10)
plt.title('''\
parabolesque gradient in x^3 region,
linear in x^2 region as expected''');
In short: symbolic differentiation is not our saving grace.
Pro: Gradients are exact, if you can compute them
Con: Need to implement in CAS. Full-featured Cas not easily available in all languages
Con: lead to expression swell by losing any structure of the program (needs to be recovered separately0
Con: Cannot handle common control-flow structures like loops and conditionals easily
What we need¶
To recap:
Finite differences is
easy to implement in any language
handles arbitrary (halting) programs but
is inaccurate unless we’re ready to pay a large computational overhead
Symbolic differentiation is:
exact to machine precision
can lead to exccessive / inefficient computation if not careful
cannot handle complex programs with control flow structures
So what we need is a third approach!
One, that is
exact
efficient
can handle arbitrayr programs
that is easy to implement in many languages
This third approach is ‘Automatic’ differentiation.
Short Interlude on Linear Transformations¶
Before we start, let’s first look at linear transformations* from ℝᵐ → ℝⁿ:
With a given basis, this is representable as a (rectangular0 matrix:
For a given linear problem, there are few ways we can run this computation
full matrix computation
i.e. we store the full (dense) \(nm\) elements of the rectangular matrix and compute an explicit matrix multiplication.
The computation can be fully generic for any matrix
def result(matrix, vector):
return np.matmul(matrix,vector)
sparse matrix computation
If many \(A_{ij}=0\), it might be wasteful to expend memory on them. We can just create a sparse matrix, by
storing only the non-zerro elements
storing a look-up table, where those elements are in the matrix
The computation can be kept general
def result(sparse_matrix, vector):
return sparse_matmul(sparse_matrix,vector)
matrix-free computation
In many cases a linear program is not explicitly given by a Matrix, but it’s given as code / a “black-box” function. As long as the computation in the body of keeps to (hard-coded) linear transformation the program is linear. The matrix elements are no longer explicitly enumerated and stored in a data structure but implicitly defined in the source code.
This is not anymore a generic computation, but each linear transformation is its own program. At the same time this is also the most memory efficient representation. No lookup table is needed since all constants are hard-coded.
def linear_program(vector):
z1,z2 = 0,0
z1 += A_11*x1
z2 += A_12*x2
z2 += A_22*x2
return [z1,z2]
Recovering Matrix Elements from matrix-free computations¶
Matrix-vector products¶
In the matrix-free setting, the program does not give access to the matrix elements, but only computes “matrix-vector products” (MVP)
We can use basis vectors to recover the matrix one column at a time
def matrix_vector_product(x):
x1,x2,x3 = x
z1,z2 = 0,0
z1 += 2*x1 #MVP statement 1
z2 += 1*x2 #MVP statement 2
z2 += 3*x3 #MVP statement 3
return np.asarray([z1,z2])
M = np.concatenate([
matrix_vector_product(np.asarray([1,0,0])).reshape(-1,1),
matrix_vector_product(np.asarray([0,1,0])).reshape(-1,1),
matrix_vector_product(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from matrix-vector products:\n{M}')
M derived from matrix-vector products:
[[2 0 0]
[0 1 3]]
Vector Matrix product (VMP)¶
The same matrix induces a “dual” linear map: ℝⁿ → ℝᵐ $\( x_k = y_i A_{ik}\)$
i.e. instead of a Matrix-vector product it’s now a vector-Matrix product (VMP)
If one has access to a “vector-Matrix” program corresponding to a matrix \(A\) one can again – as in the MVP-case – recover the matrix elements, by feeding in basis vectors.
This time the matrix is built one row at a time
def vector_matrix_product(z):
x1,x2,x3 = 0,0,0
z1,z2 = z
x3 += z2*3 #VMP version of statement 3
x2 += z2*1 #VMP version of statement 2
x1 += z1*2 #VMP version of statement 1
return np.asarray([x1,x2,x3])
M = np.concatenate([
vector_matrix_product(np.asarray([1,0])).reshape(1,-1),
vector_matrix_product(np.asarray([0,1])).reshape(1,-1),
],axis=0)
print(f'M derived from vector-matix products:\n{M}')
M derived from vector-matix products:
[[2 0 0]
[0 1 3]]
Short Recap:¶
For a given linear transformation, characterized by a matrix \(A_{ij}\) we have a forward (matrix-vector) and backward (vector-matrix) map $\(y_i = A_{ij}x_k\)\( \)\(x_j = y_i A_{ij}\)$
and we can use either to recover the full matrix \(A_{ij}\)
Wide versus Tall Transformation¶
If you look at the code above, you’ll notice that the number of calls necessary to the MVP or VMP program is related to the dimensions of matrix itself.
For a \(n\times m\) matrix (for a map: ℝᵐ → ℝⁿ), you need as \(m\) calls to the “Matrix-vector” program to built the full matrix one-column-at-a-time. Likewise you need \(n\) calls to the “vector-Matrix” program to build the matrix one-row-at-a-time.
This becomes relevant for very asymmetric maps: e.g. scalar maps from very high-dimensional spaces \(\mathbb{R}^{10000} \to \mathbb{R}\) the “vector-Matrix” appraoch is vastly more efficient than the “Matrix-vector one. There’s only one row, so only one call too the VMP program is needed to construct the full matrix!
Similarly, functions mapping few variables into very high dimensional spaces \(\mathbb{R} \to \mathbb{R}^{10000}\) it’s the opposite: the “Matrix-vector” approach is much better suited than the “vector-Matrix” one (this time it’s a single column!).
Function Compositions¶
Of course copositions \((f\circ g)(x) = f(g(x))\) of linear maps are also linear, so the above applies.
Depending on whether the “Matrix-vector” or “vector-Matrix” appraoch is used, the data is propagated forwards or backwards.
Forward |
Backward |
---|---|
From Matrices to Graphs¶
The “vector-Matrix” or “Matrix-vector” picture can be generalized to arrbitrary directed acyclic graphs.
In the “Matrix-vector” picture the node value is the edge-weighted sum of the “upstream nodes”.
In the “vector-Matrix” picture the node value is the edge-weighted sum of its “downstream nodes”.
(one could in principle always recove a rectangular/matrix-like version of a DAG by inserting trivial nodes)
| | | :———- : | : —— : |
def graph_like(x):
x1,x2,x3 = x
y1 = 2*x1+x2
z1,z2 = y1+2*x3,x3-y1 #note that we reach "over" the "ys" and diectly touch x_n
return np.asarray([z1,z2])
def matrix_like(x):
x1,x2,x3 = x
y1 = 2*x1+x2
y2 = x3 #can just introduce a dummy variable to make it matrix-like
z1,z2 = y1+2*x3,y2-y1
return np.asarray([z1,z2])
M = np.concatenate([
matrix_like(np.asarray([1,0,0])).reshape(-1,1),
matrix_like(np.asarray([0,1,0])).reshape(-1,1),
matrix_like(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from matrix like computation:\n{M}')
M = np.concatenate([
graph_like(np.asarray([1,0,0])).reshape(-1,1),
graph_like(np.asarray([0,1,0])).reshape(-1,1),
graph_like(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from graph-like products:\n{M}')
M derived from matrix like computation:
[[ 2 1 2]
[-2 -1 1]]
M derived from graph-like products:
[[ 2 1 2]
[-2 -1 1]]
Derivatives¶
Why are we talking about linear transformations? After all lot of the code we write is non-linear! However, derivatives are always linear.
And derivatives (the jacobian) of a composition \(f\circ g\) is the composition of linear derivatives (the jacobians of each map) i.e. the full jacobian Matrix is the result of multipying all Jacobians of the composition. $\(J = J_0 J_1 J_2 J_3 \dots J_n \)$
(This is just the chain rule) $\(z = f(y) = f(g(x))\hspace{1cm} \frac{\partial f_i}{\partial x_j} = \frac{\partial f_i}{\partial z_j}\frac{\partial z_j}{\partial x_k}\)$
I.e. finding derivatives, means characterizing the jacobian matrix. From the above discussion, we can use the “Jacobian-vector product” (JVP, builds Jacobians column-wise) or “vector-Jacobian product” (builds Jacobians row-wise) approach.
In the language of automatic differentiation
Jacobian-vector products (JVP) = forward mode (forward propagation)
vector-Jacobian products (VJP) = reverse mode (reverse propagation)
Example¶
Let’s work this out on a very simple problem
In the forward pass we use “Matrix-vector” products and need to do two evaluation
In the backward pass we use “vector-Matrix” products and need to do only a single evaluation
Both approaches give the same result. Since this is a map from \(\mathbb{R}^2 \to \mathbb{R}^1\) the backward pass is more efficient than the forward pass
Let’s look at a real-life example
This is easy python code
def mul_func(x1,x2):
return x1*x2
def sum_func(x1,x2):
return x1+x2
def function(x):
x1,x2 = x
y = mul_func(x1,x2)
z = sum_func(y,x2)
return z
print(function([2,4]))
12
In the forward pass, an autodiff system needs to create a JVP implementation for each elementary operation
def mul_jvp(x1,dx1,x2,dx2):
y = mul_func(x1,x2)
dy = x1*dx2 + x2*dx1
return y, dy
def sum_jvp(x1,dx1,x2,dx2):
return sum_func(x1,x2), dx1 + dx2
def function_jvp(x,dx):
x1,x2 = x
dx1,dx2 = dx
y, dy = mul_jvp(x1,dx1,x2,dx2)
z, dz = sum_jvp(y,dy, x2, dx2)
return z,dz
Since in the forward pass we build “column-at a time” and our final jacobian is has shape (1x2), i.e. two columns we need two forward passes to get the full Jacobian. Not that for eacch forward pass we also get the fully computed functino value delivered on top!
Also note that the “JVP” version of the functino has the same structure as the original function. For each call in the original program there is an equivalent call in the JVP program. However the JVP call does always two things at once
compute the nominal result
compute the differentials
So it has roughly 2x the run-time as the original program (depending on the complexity of the derivatives). Said another way: computing the one-pass in the derivative has the same computational complexity as the function itself.
print(function_jvp([2,4],[1,0]))
print(function_jvp([2,4],[0,1]))
(12, 4)
(12, 3)
For the backward pass we build “row-at-a-time’. For each elementary operation we need to build a VJP implementation
def mul_vjp(x1,x2,dx1,dx2,dout):
dx2 += dout * x1
dx1 += dout * x2
return dx1,dx2
def sum_vjp(x1,x2,dx1,dx2,dout):
dx1 += dout * 1
dx2 += dout * 1
return dx1,dx2
def function_vjp(x,dz):
#run forward
x1,x2 = x
y = mul_func(x1,x2)
z = sum_func(y,x2)
#zero gradients
dy = 0
dx1 = 0
dx2 = 0
#run backward
dy,dx2 = sum_vjp(y,x1, dy, dx2, dz)
dx1,dx2 = mul_vjp(x1,x2, dx1, dx2, dy)
return z,[dx1,dx2]
Here, we see the power of backward propagation (or the reverse mode) we get all gradients of the single row ine oone go. Since this Jacobian only has one row, we’re done! And we get the function value delivered on top of the gradients as well!
print(function_vjp([2,4],1.0))
(12, [4.0, 3.0])
Again, let’s look at the “VJP” code. The forward pass is exactly the same as the original function. This just records the final result and all intermediate values, which we will need for the backward pass.
Moving on to the backward pass, we see again, as in JVP, it has the same structure as the forward pass. For each call to a subroutine there is an equivalent call in the backward pass to compute the VJP.
As in the JVP case, the computational complexity of one backward pass is roughly the same as the forward pass. Now unlike the JVP-case we only needed a single pass for all the gradients of this scalar function. So obtaining the full gradient of a function is only as expensive as the function itself.
Recap:¶
Above we have built a manual autodiff system. Let’s recap what we needed to do
define a set of operations we want to be differentiable
define sub-routines for nominal operations, JVP and VJP
Once given a program, we had to do the following
In the forward mode:
just replace the nominal function with the JVP one
for each variable in the program allocate a “differential” variable and pass it into the JVP whereever we also pass the nominal variable
In the backward mode:
Run the program forward, keep track of all values
keep track of the order of operations on a “record” of sorts
allocate “differential” variables for all values and initialize to zero
use the record to replay the order of operations backwards, passing along the appropriate differential values, and updating the relevant ones with the result of the VJP
All of this is pretty mechanistic and hence “automatable”. And given that it’s a very narrow domain of only implementing JVP/JVP operations this is easy to do in any language.
That’s why it’s automatic differentiation
What we gain from this is that we get
exact derivatives (to machine precision) for arbitrary composed of the operations we define
complexity of a derivative-pass through the program is of same order of complexity as the original program
often only a single pass is necessary (e.g. scalar multi-variate functions)
unlike symbolic differrentiation, the structure of the program is preserved and allows naturally to avoid repetitive calculations of the same values
(we will see that) arbitrary control flows are handles naturally
it’s something that is easy for a comoputer do and for a progarmmer to imlpement
Some notes on pros and cons:
In the forward mode:
the signature of each opeartion basically extends
float f(float x,float y,float z)
to
pair<float> f(float x,float dx,float y,float float dy, float z,float dz)
if you use composite types (“dual numbers”) that hold both x,dx you can basically keep the signature unchanged
f(dual x, dual x, dual z)
together with operator overloading on these dual types e.g.
dual * dual
you can essentially keep the source code unchangedfloat f(float x, float y): return x*y
->
dual f(dual x,dual y): return x*y
That means it’s very easy implement. And memory efficient, no superfluous values are kept when they run out of scope.
But forward more better for vector-value functions of few parameters
In the reverse mode:
very efficient, but we need to keep track of order (need a “tape” of sorts)
since we need to access all intermediate varriables, we can run into memory bounds
the procedurer is a bit more complex than fwd: 1) run fwd, 2) zero grads 3) run bwd
I don’t want to implement an autodiff system.. Aren’t there libraries for this??¶
Yes there are! And a lot of them in many languages. On the othe rhand, try finding CAS systems in each of those
This is PyHEP, so let’s focus on Python. Here, basically what you think of as “Machine Learning frameworks” are at the core autodiff libraries
Tensorflow
PyTorch
JAX
Let’s focus on jax
import jax
import jax.numpy as jnp
def f(x):
return x**2
jax.numpy
is almost a drop-in rerplacement for numpy
. I do import jax.numpy as jnp
but if you’re daring you could do import jax.numpy as np
x = jnp.array([1,2,3])
y = jnp.array([2,3,4])
/Users/cranmer/anaconda3/envs/stats-book-2/lib/python3.8/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
print(x+y)
print(x*y)
print(jnp.log(x))
print(jnp.exp(y))
[3 5 7]
[ 2 6 12]
[0. 0.69314718 1.09861229]
[ 7.3890561 20.08553692 54.59815003]
def f(x):
return x**3
print(f(4.0))
print(jax.grad(f)(4.0)) #boom!
print(jax.grad(jax.grad(f))(4.0)) #boom!
print(jax.grad(jax.grad(jax.grad(f)))(4.0)) #boom!
print(jax.grad(jax.grad(jax.grad(jax.grad(f))))(4.0)) #boom!
64.0
48.0
24.0
6.0
0.0
xi = jnp.linspace(-5,5)
yi = f(xi)
plt.plot(xi,yi)
[<matplotlib.lines.Line2D at 0x7f972aeb56a0>]
try:
jax.grad(f)(xi)
except TypeError as err:
print(err)
Gradient only defined for scalar-output functions. Output had shape: (50,).
Whoops, jax.grad defaults to reverse mode with a single backward pass, but through broadcasting we get a vector -> vector
map. We can use some jax magic to “unbroadcast” the function, take the gradient and re-broadcast it
jax.vmap(jax.grad(f))(xi)
DeviceArray([7.50000000e+01, 6.90024990e+01, 6.32548938e+01,
5.77571845e+01, 5.25093711e+01, 4.75114536e+01,
4.27634319e+01, 3.82653061e+01, 3.40170762e+01,
3.00187422e+01, 2.62703040e+01, 2.27717618e+01,
1.95231154e+01, 1.65243648e+01, 1.37755102e+01,
1.12765514e+01, 9.02748855e+00, 7.02832153e+00,
5.27905040e+00, 3.77967514e+00, 2.53019575e+00,
1.53061224e+00, 7.80924615e-01, 2.81132861e-01,
3.12369846e-02, 3.12369846e-02, 2.81132861e-01,
7.80924615e-01, 1.53061224e+00, 2.53019575e+00,
3.77967514e+00, 5.27905040e+00, 7.02832153e+00,
9.02748855e+00, 1.12765514e+01, 1.37755102e+01,
1.65243648e+01, 1.95231154e+01, 2.27717618e+01,
2.62703040e+01, 3.00187422e+01, 3.40170762e+01,
3.82653061e+01, 4.27634319e+01, 4.75114536e+01,
5.25093711e+01, 5.77571845e+01, 6.32548938e+01,
6.90024990e+01, 7.50000000e+01], dtype=float64)
that looks better!
jax.grad(f)
just returns another function. Of course we can just
take the gradient of that as well. And so on…
g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.plot(xi,g2i, label = "f''")
plt.plot(xi,g3i, label = "f'''")
plt.legend()
<matplotlib.legend.Legend at 0x7f97007d0c40>
Control Flow¶
Back when discussing symbolic differentiation we hit a snag when adding control flow through to our prorgam. In Jax this just passes through transparently.
Let’s compare this to finite differences. So far the only system we had to compute derivatives of control-flow-ful programs
def control_flow_func(x):
if x > 2:
return x**2
else:
return x**3
first_gradient_of_cflow = jax.grad(control_flow_func)
xi = jnp.linspace(-2,5,101)
yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])
plt.plot(xi,yi,c = 'k')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])
plt.scatter(xi,yi, label = 'jax autodiff')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([control_flow_func(xx) for xx in xi])
plt.scatter(xi,np.gradient(yi,xi), label = 'finite differences')
plt.legend()
<matplotlib.legend.Legend at 0x7f9718b2dee0>
We can start to see the benefits autodiff. Among other things, finite differnces becomes quite sensitive to exactly where the evaluation points are (e.g. wrt to the discontinuity)
As we compute higher derivatives, this error compounds badly for finite differences. But for autodiff, it’s smooth sailing!
second_gradient_of_cflow = jax.grad(first_gradient_of_cflow)
xi = jnp.linspace(-2,5,101)
yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])
plt.plot(xi,yi,c = 'k')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])
plt.scatter(xi,yi, label = '2nd deriv jax autodiff')
xi = jnp.linspace(-2,5,11)
yi = np.asarray([control_flow_func(xx) for xx in xi])
plt.scatter(xi,np.gradient(np.gradient(yi),xi), label = '2nd deriv finite differences',)
plt.legend()
<matplotlib.legend.Legend at 0x7f96f8748f70>
Custom Operations¶
Not all our programs are so simple. Consider this
def func(x)
y_root = solve(x^2 + y^2 == 1,x = x, y_start = 2.0)
return y_root
solving this often goes through some iterative algorithm like Brent bracketing But, differentiating through the iteration is not the right solution.
We can add our own custom gradients
Recall the implicit function theorem $\( f(x,y) = x^2 + y^2 -1 = 0 \)$
How do we teach this an autodiff system:
Recall:
we can choose which operations we consider “fundamental”
we don’t need to constrain ourselves to the lowest possible representationo
import jax
from jax import core
import numpy as np
from jax.interpreters import ad
import scipy
import functools
import matplotlib.pyplot as plt
def findroot(f):
return scipy.optimize.brentq(f,a = 0,b = 10)
def func(x,y):
return x**2 + y**2 - 1
def y_for_x(x):
return findroot(functools.partial(func,x))
xi = np.linspace(-1,1)
yi = np.asarray([y_for_x(xx) for xx in xi])
plt.plot(xi,yi)
findrootjax_p = core.Primitive('findrootjax')
findrootjax_p.def_impl(lambda x: y_for_x(x))
ad.defvjp(findrootjax_p, lambda g, x: - x / y_for_x(x))
def findrootjax(x):
return findrootjax_p.bind(x)
jax.value_and_grad(findrootjax)(0.5)
xi = np.linspace(-1,1,101)
yi = [findrootjax(v) for v in xi]
plt.plot(xi,yi)
xi = np.linspace(-1,1,21)
vg = np.asarray([np.asarray(jax.value_and_grad(findrootjax)(v)) for v in xi])
plt.scatter(xi,vg[:,0])
plt.quiver(xi,vg[:,0],np.ones_like(vg[:,0]),vg[:,1],
angles = 'uv',
alpha = 0.5,
)
plt.gcf().set_size_inches(5,2.5)
In HEP¶
Of course we can use automatic differentiation for neural networks. But other things in HEP also can make use of gradients. A prime example where this is the case is statistical analysis
For a maximum likelihood fit we want to minimize the log likelihood
\(\theta^* = \mathrm{argmin}_\theta(\log L)\)
import jax
import jax.numpy as jnp
import numpy as np
import pyhf
import matplotlib.pyplot as plt
pyhf.set_backend('jax')
m = pyhf.simplemodels.hepdata_like([5.],[10.],[3.5])
pars = jnp.array(m.config.suggested_init())
data = jnp.array([15.] + m.config.auxdata)
m.logpdf(pars,data)
DeviceArray([-4.25748227], dtype=float64)
bestfit = pyhf.infer.mle.fit(data,m)
bestfit
DeviceArray([1., 1.], dtype=float64)
grid = x,y = np.mgrid[.5:1.5:101j,.5:1.5:101j]
points = np.swapaxes(grid,0,-1).reshape(-1,2)
v = jax.vmap(m.logpdf, in_axes = (0,None))(points,data)
v = np.swapaxes(v.reshape(101,101),0,-1)
plt.contourf(x,y,v, levels = 100)
plt.contour(x,y,v, levels = 20, colors = 'w')
grid = x,y = np.mgrid[.5:1.5:11j,.5:1.5:11j]
points = np.swapaxes(grid,0,-1).reshape(-1,2)
values, gradients = jax.vmap(
jax.value_and_grad(
lambda p,d: m.logpdf(p,d)[0]
), in_axes = (0,None)
)(points,data)
plt.quiver(
points[:,0],
points[:,1],
gradients[:,0],
gradients[:,1],
angles = 'xy',
scale = 75
)
plt.scatter(bestfit[0],bestfit[1], c = 'r')
plt.xlim(0.5,1.5)
plt.ylim(0.5,1.5)
plt.gcf().set_size_inches(5,5)