Revisiting error propagation with automatic differentiation
By Kyle Cranmer, March 2, 2020
This notebook is dedicated to Freeman Dyson, who died on February 28, 2020, in Princeton, NJ, at the age of 96.
“New directions in science are launched by new tools much more often than by new concepts. The effect of a concept-driven revolution is to explain old things in new ways. The effect of a tool-driven revolution is to discover new things that have to be explained.”
-- Freeman Dyson
Reminder of propagation of errors
This notebook was made to investigate the propagation of errors formula. We imagine that we have a function $q(x,y)$ and we want to propagate the uncertainty on $x$ and $y$ (denoted $\sigma_x$ and $\sigma_y$, respectively) through to the quantity $q$.
The most straightforward way to do this is to randomly sample $x$ and $y$, evaluate $q$, and look at its distribution. This is really the definition of what we mean by propagation of uncertainty, and it's very easy to do with some simple Python code.
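For example, here is a minimal Monte Carlo sketch using standard NumPy (the central values 2.0 and 3.0, the uncertainties 0.1, and the choice $q = xy$ are all just illustrative):
import numpy as onp  # standard NumPy for random sampling

N = 100000
x_samples = onp.random.normal(2.0, 0.1, N)  # x = 2.0 +/- 0.1
y_samples = onp.random.normal(3.0, 0.1, N)  # y = 3.0 +/- 0.1
q_samples = x_samples * y_samples           # q(x, y) = x*y as an example
print('q =', onp.mean(q_samples), '+/-', onp.std(q_samples))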
The calculus formula for the propagation of errors is really an approximation. This is the formula for a general $q(x,y)$ \begin{equation} \sigma_q^2 = \left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2 \end{equation}
In the special case of addition $q(x,y) = x\pm y$ we have $\sigma_q^2 = \sigma_x^2 + \sigma_y^2$.
In the special case of multiplication $q(x,y) = x y$ and division $q(x,y) = x / y$ we have $(\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2$, which we can rewrite as $\sigma_q = |q| \sqrt{(\sigma_x/x)^2 + (\sigma_y/y)^2}$, with $q = xy$ or $q = x/y$ as appropriate.
Let's try out these formulas and compare the direct approach (sampling the distribution) to the prediction from these formulas.
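For instance, sticking with the product example, here is a sketch that checks the sampled standard deviation against the formula's prediction (reusing the samples from the cell above):
# propagation-of-errors prediction for q = x*y with x = 2.0 +/- 0.1, y = 3.0 +/- 0.1
q_central = 2.0 * 3.0
sigma_q = q_central * onp.sqrt((0.1/2.0)**2 + (0.1/3.0)**2)
print('formula:     sigma_q =', sigma_q)             # ~0.36
print('Monte Carlo: std     =', onp.std(q_samples))  # should agree closely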
Automatic Differentiation
Excerpts from the Wikipedia article: https://en.wikipedia.org/wiki/Automatic_differentiation
In mathematics and computer algebra, automatic differentiation (AD), also called algorithmic differentiation or computational differentiation, is a set of techniques to numerically evaluate the derivative of a function specified by a computer program. AD exploits the fact that every computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (addition, subtraction, multiplication, division, etc.) and elementary functions (exp, log, sin, cos, etc.). By applying the chain rule repeatedly to these operations, derivatives of arbitrary order can be computed automatically, accurately to working precision, and using at most a small constant factor more arithmetic operations than the original program.
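As a quick illustration of that last point, we can compare an automatic derivative to the analytic one (a minimal sketch with jax.grad; the function $f$ is just an example):
from jax import grad
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.exp(x)  # built from elementary operations

df = grad(f)  # AD applies the chain rule through sin, exp, and *

print(df(1.3))                                                 # automatic derivative
print(jnp.cos(1.3)*jnp.exp(1.3) + jnp.sin(1.3)*jnp.exp(1.3))   # analytic: (cos x + sin x) e^x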
Usually, two distinct modes of AD are presented, forward accumulation (or forward mode) and reverse accumulation (or reverse mode). Forward accumulation specifies that one traverses the chain rule from inside to outside (that is, first compute $dw_1/dx$, then $dw_2/dw_1$, and at last $dy/dw_2$), while reverse accumulation has the traversal from outside to inside (first compute $dy/dw_2$, then $dw_2/dw_1$, and at last $dw_1/dx$). More succinctly,
forward accumulation computes the recursive relation $\frac{dw_i}{dx} = \frac{dw_i}{dw_{i-1}} \frac{dw_{i-1}}{dx}$ with $w_3 = y$, and
reverse accumulation computes the recursive relation $\frac{dy}{dw_i} = \frac{dy}{dw_{i+1}} \frac{dw_{i+1}}{dw_i}$ with $w_0 = x$.
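Both modes compute the same Jacobian; in JAX they are exposed as jacfwd and jacrev (a minimal sketch; the function f is just an example):
from jax import jacfwd, jacrev
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]) + x[1]**2])

x0 = jnp.array([2., 3.])
print(jacfwd(f)(x0))  # forward mode: one pass per input dimension
print(jacrev(f)(x0))  # reverse mode: one pass per output dimension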
We will use https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
from jax import grad, jacfwd
import jax.numpy as np
Now here are 3 lines of code for the propagation of uncertainty formula
\begin{equation} \sigma_q = \sqrt{\left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2} \end{equation}
def error_prop_jax_gen(q, x, dx):
    jac = jacfwd(q)
    return np.sqrt(np.sum(np.power(jac(x)*dx, 2)))
x_ = np.array([2.,3.])
dx_ = np.array([.1,.1])
def q(x):
    return x[0] + x[1]
def error_prop_classic(x, dx):
    # for q = x[0] + x[1]
    ret = dx[0]**2 + dx[1]**2
    return np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
Multiplication and Division
In the special case of multiplication \begin{equation} q(x,y) = x y \end{equation} and division \begin{equation} q(x,y) = \frac{x}{y} \end{equation}
\begin{equation} (\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2 \end{equation} which we can rewrite as \begin{equation} \sigma_q = |q| \sqrt{\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2}, \end{equation} with $q = xy$ for multiplication and $q = x/y$ for division.
def q(x):
    return x[0]*x[1]
def error_prop_classic(x, dx):
    # for q = x[0]*x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2
    return (x[0]*x[1])*np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
def q(x):
    return x[0]/x[1]
def error_prop_classic(x, dx):
    # for q = x[0]/x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2
    return (x[0]/x[1])*np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
Powers
In the special case of powers, $q(x,y) = x^m y^n$, we have
\begin{equation} (\sigma_q/q)^2 = \left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2 \end{equation} which we can rewrite as \begin{equation} \sigma_q = x^m y^n \sqrt{\left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2} \end{equation}
def q(x, m=2, n=3):
    return np.power(x[0], m) * np.power(x[1], n)
x_ = np.array([1.5, 2.5])
dx_ = np.array([.1, .1])
q(x_)
def error_prop_classic(x, dx):
    # for q = x[0]**2 * x[1]**3
    dq_ = q(x) * np.sqrt(np.power(2*dx[0]/x[0], 2) + np.power(3*dx[1]/x[1], 2))
    return dq_
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
Misc Examples
See some examples here:
http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html
Example: w = (4.52 ± 0.02) cm, A = (2.0 ± 0.2), y = (3.0 ± 0.6) cm. Find \begin{equation} z = \frac{w y^2}{\sqrt{A}} \end{equation}
The second relative error, (Δy/y), is multiplied by 2 because the power of $y$ is 2. The third relative error, (ΔA/A), is multiplied by 0.5 since a square root is a power of one half.
So Δz = 0.49(28.638) = 14.03, which we round to 14, giving z = (29 ± 14). Using Eq. 3b instead, z = (29 ± 12). Because the uncertainty begins with a 1, we keep two significant figures and round the answer to match.
def q(x):
    # x = [w, A, y]
    return x[0]*x[2]*x[2]/np.sqrt(x[1])
x_ = np.array([4.52, 2., 3.]) #w,A,y
dx_ = np.array([.02, .2, .6])
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
import numpy as onp #using jax as np right now
w_ = onp.random.normal(x_[0], dx_[0], 10000)
A_ = onp.random.normal(x_[1], dx_[1], 10000)
y_ = onp.random.normal(x_[2], dx_[2], 10000)
x__ = np.vstack((w_, A_, y_))
z_ = q(x__)
print('mean =', np.mean(z_), 'std = ', np.std(z_))
import matplotlib.pyplot as plt
_ = plt.hist(z_, bins=50)
also taken from http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html
w = (4.52 ± 0.02) cm, x = (2.0 ± 0.2) cm, y = (3.0 ± 0.6) cm.
Find \begin{equation} z = w x +y^2 \end{equation}
We have v = wx = (9.0 ± 0.9) cm².
The calculation of the uncertainty in $y^2$ is the same as in the previous example. Then from Eq. 1b, Δz = 3.7, so z = (18 ± 4) cm².
def q(x):
    # x = [w, x, y]
    return x[0]*x[1] + x[2]*x[2]
x_ = np.array([4.52, 2., 3.]) #w,x,y
dx_ = np.array([.02, .2, .6])
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
def q(x):
    return np.sum(x)
x_ = 1.*np.arange(1,101) #counts from 1-100 (and 1.* to make them floats)
dx_ = 0.1*np.ones(100)
The sum from $1$ to $N$ is $N(N+1)/2$ (see the story of Gauss), so we expect $q(x) = 5050$. And the uncertainty should be $\sqrt{100} \times 0.1 = 1$.
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
Another toy example: the product from 1 to 10.
def q(x):
    return np.prod(x)  # product of all entries
x_ = 1.*np.arange(1,11) #counts from 1-10 (and 1.* to make them floats)
dx_ = 0.1*np.ones(10)
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
Checking this is an exercise left to the reader :-)
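If you want a hint: for a product, the relative uncertainties add in quadrature, so one possible check (a sketch) is
# (sigma_q/q)^2 = sum_i (sigma_i/x_i)^2 for a product
print('expected:', q(x_) * np.sqrt(np.sum(np.power(dx_/x_, 2))))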