Revisiting error propagation with automatic differentiation

By Kyle Cranmer, March 2, 2020

This notebook is dedicated to Freeman Dyson, who died on February 28, 2020 in Princeton, NJ at the age of 96.

“New directions in science are launched by new tools much more often than by new concepts. The effect of a concept-driven revolution is to explain old things in new ways. The effect of a tool-driven revolution is to discover new things that have to be explained.”

– Freeman Dyson

Reminder of propagation of errors

This notebook was made to investigate the propagation of errors formula. We imagine that we have a function \(q(x,y)\) and we want to propagate the uncertainty on \(x\) and \(y\) (denoted \(\sigma_x\) and \(\sigma_y\), respectively) through to the quantity \(q\).

The most straightforward way to do this is to randomly sample \(x\) and \(y\), evaluate \(q\), and look at its distribution. This is really the definition of what we mean by propagation of uncertainty, and it’s very easy to do with some simple Python code.
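Here is a minimal sketch of that direct approach. It assumes each input is an independent Gaussian with mean x[i] and standard deviation dx[i]. error_prop_mc and q_sum are hypothetical helper names, and the sample size and seed are arbitrary choices.

```python
import numpy as onp  # plain NumPy, named onp as elsewhere in this notebook


def error_prop_mc(q, x, dx, n=100_000, seed=0):
    """Monte Carlo propagation: draw each input from an independent Gaussian
    with mean x[i] and std dx[i], evaluate q on every draw, and return the
    mean and standard deviation of the resulting values of q."""
    rng = onp.random.default_rng(seed)
    # shape (n, len(x)), transposed so that samples[i] holds all draws of input i
    samples = rng.normal(loc=x, scale=dx, size=(n, len(x))).T
    values = q(samples)
    return values.mean(), values.std()


def q_sum(x):
    return x[0] + x[1]


mean, std = error_prop_mc(q_sum, onp.array([2., 3.]), onp.array([.1, .1]))
print(mean, std)  # std should be close to sqrt(0.1**2 + 0.1**2), roughly 0.141
```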

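The calculus formula for the propagation of errors is really an approximation. It linearizes \(q\) around the central values and assumes \(x\) and \(y\) are uncorrelated, so it works best when \(q\) is close to linear over the range of the uncertainties. This is the formula for a general \(q(x,y)\):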

\[\begin{equation} \sigma_q^2 = \left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2 \end{equation}\]
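A sketch of where this comes from, writing \((\bar x, \bar y)\) for the central values. First expand \(q\) to first order:

\[\begin{equation} q(x,y) \approx q(\bar x,\bar y) + \frac{\partial q}{\partial x}(x-\bar x) + \frac{\partial q}{\partial y}(y-\bar y) \end{equation}\]

Here the derivatives are evaluated at \((\bar x, \bar y)\). For uncorrelated \(x\) and \(y\), the variance of \(a x + b y\) is \(a^2\sigma_x^2 + b^2\sigma_y^2\), so

\[\begin{equation} \sigma_q^2 \approx \left( \frac{\partial q}{\partial x} \right)^2 \sigma_x^2 + \left( \frac{\partial q}{\partial y} \right)^2 \sigma_y^2 \end{equation}\]

which is the formula above.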

In the special case of addition \(q(x,y) = x\pm y\) we have \(\sigma_q^2 = \sigma_x^2 + \sigma_y^2\).

In the special case of multiplication \(q(x,y) = x y\) and division \(q(x,y) = x / y\) we have \((\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2\). We can rewrite this as \(\sigma_q = |q| \sqrt{(\sigma_x/x)^2 + (\sigma_y/y)^2}\), where \(|q| = xy\) for the product and \(|q| = x/y\) for the quotient when \(x\) and \(y\) are positive.

Let’s try out these formulas and compare the direct approach of making the distribution with the prediction from these formulas.

Automatic Differentiation

Excerpts from the Wikipedia article: https://en.wikipedia.org/wiki/Automatic_differentiation

In mathematics and computer algebra, automatic differentiation (AD), also called algorithmic differentiation or computational differentiation,[1][2] is a set of techniques to numerically evaluate the derivative of a function specified by a computer program. AD exploits the fact that every computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (addition, subtraction, multiplication, division, etc.) and elementary functions (exp, log, sin, cos, etc.). By applying the chain rule repeatedly to these operations, derivatives of arbitrary order can be computed automatically, accurately to working precision, and using at most a small constant factor more arithmetic operations than the original program.

For a composition \(y = f(g(h(x)))\) with intermediate values \(w_0 = x\), \(w_1 = h(w_0)\), \(w_2 = g(w_1)\) and \(w_3 = f(w_2) = y\), the chain rule gives \(dy/dx = (dy/dw_2)(dw_2/dw_1)(dw_1/dx)\). Usually, two distinct modes of AD are presented: forward accumulation (or forward mode) and reverse accumulation (or reverse mode). Forward accumulation traverses the chain rule from inside to outside (that is, first compute \(dw_1/dx\), then \(dw_2/dw_1\), and at last \(dy/dw_2\)). Reverse accumulation traverses it from outside to inside (first compute \(dy/dw_2\), then \(dw_2/dw_1\), and at last \(dw_1/dx\)). More succinctly,

  • forward accumulation computes the recursive relation \(\frac{dw_i}{dx} = \frac{dw_i}{dw_{i-1}} \frac{dw_{i-1}}{dx}\) with \(w_3 = y\), and

  • reverse accumulation computes the recursive relation \(\frac{dy}{dw_i} = \frac{dy}{dw_{i+1}} \frac{dw_{i+1}}{dw_i}\) with \(w_0 = x\).
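Forward accumulation is commonly implemented with dual numbers. Every value carries its derivative along with it, and each elementary operation updates both using the chain rule. Below is a minimal, illustrative sketch in plain Python. The Dual class and derivative helper are hypothetical names, not part of JAX, and only +, * and sin are supported.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A value paired with its derivative with respect to the input x."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # sum rule: derivatives add
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__


def sin(d):
    # chain rule for an elementary function: d/dx sin(w) = cos(w) * dw/dx
    return Dual(math.sin(d.val), math.cos(d.val) * d.dot)


def derivative(f, x):
    """Forward-mode derivative of a one-argument function f at x (seed dx/dx = 1)."""
    return f(Dual(x, 1.0)).dot


# f(x) = x*sin(x) + 3x, so f'(x) = sin(x) + x*cos(x) + 3
f = lambda x: x * sin(x) + 3 * x
print(derivative(f, 2.0), math.sin(2.0) + 2.0 * math.cos(2.0) + 3)
```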

We will use JAX for the automatic differentiation, following its autodiff cookbook: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

from jax import grad, jacfwd
import jax.numpy as np

Now here are 3 lines of code for the propagation of uncertainty formula

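\[\begin{equation} \sigma_q = \sqrt{\left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2} \end{equation}\]
def error_prop_jax_gen(q,x,dx):
    # for a scalar-valued q, jacfwd(q)(x) is the vector of partial derivatives dq/dx[i]
    jac = jacfwd(q)
    # add the contributions (dq/dx[i] * dx[i]) in quadrature
    return np.sqrt(np.sum(np.power(jac(x)*dx,2)))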
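jacfwd computes the derivatives with forward-mode AD. For a scalar-valued q, jax.grad (reverse mode) returns the same vector of partial derivatives, and reverse mode tends to be the cheaper choice when there are many inputs and a single output. Below is a sketch of a reverse-mode variant. error_prop_jax_rev and q_prod are hypothetical names introduced here.

```python
import jax.numpy as np
from jax import grad, jacfwd


def error_prop_jax_rev(q, x, dx):
    """Hypothetical reverse-mode variant of error_prop_jax_gen, using grad."""
    return np.sqrt(np.sum(np.power(grad(q)(x) * dx, 2)))


def q_prod(x):
    return x[0] * x[1]


x_ = np.array([2., 3.])
dx_ = np.array([.1, .1])

# forward mode (jacfwd) and reverse mode (grad) give the same partials: [3., 2.]
print(jacfwd(q_prod)(x_), grad(q_prod)(x_))
print(error_prop_jax_rev(q_prod, x_, dx_))
```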

Setup two observations with uncertainties

Below I’ll use \(x\) and \(y\) for symbols, but they will be stored in the array x so that x[0] = \(x\) and x[1] = \(y\).

x_ = np.array([2.,3.])
dx_ = np.array([.1,.1])

Addition and Subtraction

In the special case of addition \(q(x,y) = x\pm y\) we have \(\sigma_q^2 = \sigma_x^2 + \sigma_y^2\).

def q(x):
    return x[0]+x[1]
def error_prop_classic(x, dx):
    # for q = x[0]+x[1]
    ret = dx[0]**2 + dx[1]**2
    return np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
q =  5.0 +/- 0.14142136
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
q =  5.0 +/- 0.14142136

Multiplication and Division

In the special case of multiplication

\[\begin{equation} q(x,y) = x y \end{equation}\]

and division

\[\begin{equation} q(x,y) = \frac{x}{y} \end{equation}\]

we have

\[\begin{equation} (\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2 \end{equation}\]

which we can rewrite as

\[\begin{equation} \sigma_q = |q| \sqrt{\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2} \end{equation}\]

where \(|q| = xy\) for the product and \(|q| = x/y\) for the quotient when \(x\) and \(y\) are positive.
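As a check, the relative form follows from the general formula. For the product, \(\partial q/\partial x = y\) and \(\partial q/\partial y = x\), so

\[\begin{equation} \sigma_q^2 = y^2\sigma_x^2 + x^2\sigma_y^2 = (xy)^2\left[\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2\right] \end{equation}\]

For the quotient, \(\partial q/\partial x = 1/y\) and \(\partial q/\partial y = -x/y^2\), so

\[\begin{equation} \sigma_q^2 = \frac{\sigma_x^2}{y^2} + \frac{x^2\sigma_y^2}{y^4} = \left(\frac{x}{y}\right)^2\left[\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2\right] \end{equation}\]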
def q(x):
    return x[0]*x[1]
def error_prop_classic(x, dx):
    # for q = x[0]*x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2 
    return (x[0]*x[1])*np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
q =  6.0 +/- 0.36055514
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
q =  6.0 +/- 0.36055514
def q(x):
    return x[0]/x[1]
def error_prop_classic(x, dx):
    # for q = x[0]/x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2 
    return (x[0]/x[1])*np.sqrt(ret)
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
q =  0.6666667 +/- 0.040061682
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
q =  0.6666667 +/- 0.040061682

Powers

For \(q(x,y) = x^m y^n\) we have

\[\begin{equation} (\sigma_q/q)^2 = \left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2 \end{equation}\]

which we can rewrite as

\[\begin{equation} \sigma_q = x^m y^n \sqrt{\left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2} \end{equation}\]
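The power rule contains the earlier special cases: \(m = n = 1\) is the product and \(m = 1, n = -1\) is the quotient. The sketch below checks this numerically for the quotient with x = 2 ± 0.1 and y = 3 ± 0.1. It compares the closed-form rule with the same linearized formula using derivatives from jacfwd. power_law_error_classic and power_law_error_jax are hypothetical helper names, and positive x and y are assumed.

```python
import jax.numpy as np
from jax import jacfwd


def power_law_error_classic(x, dx, m, n):
    """Relative-error rule for q = x**m * y**n (assumes x, y > 0)."""
    qv = np.power(x[0], m) * np.power(x[1], n)
    return qv * np.sqrt((abs(m) * dx[0] / x[0]) ** 2 + (abs(n) * dx[1] / x[1]) ** 2)


def power_law_error_jax(x, dx, m, n):
    # same linearized formula, with the partial derivatives taken by jacfwd
    jac = jacfwd(lambda v: np.power(v[0], m) * np.power(v[1], n))(x)
    return np.sqrt(np.sum((jac * dx) ** 2))


x_ = np.array([2., 3.])
dx_ = np.array([.1, .1])
# m=1, n=-1 is the quotient x/y, so both should match the division result above
print(power_law_error_classic(x_, dx_, 1, -1), power_law_error_jax(x_, dx_, 1, -1))
```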
def q(x, m=2, n=3):
    return np.power(x[0],m)*np.power(x[1],n)
x_ = np.array([1.5, 2.5])
dx_ = np.array([.1, .1])
q(x_)
DeviceArray(35.15625, dtype=float32)
def error_prop_classic(x, dx):
    # for q = x[0]**2 * x[1]**3 (m=2, n=3), using the arguments rather than the globals x_, dx_
    dq_ = q(x)*np.sqrt(np.power(2*dx[0]/x[0],2)+np.power(3*dx[1]/x[1],2))
    return dq_
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
q =  35.15625 +/- 6.3063865
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
q =  35.15625 +/- 6.3063865

Misc Examples

See some examples here:

http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html

Example: w = (4.52 ± 0.02) cm, A = (2.0 ± 0.2), y = (3.0 ± 0.6) cm. Find

\[\begin{equation} z=\frac{wy^2}{\sqrt{A}} \end{equation}\]

From the worked solution on that page: the second relative error, \(\Delta y/y\), is multiplied by 2 because the power of y is 2. The third relative error, \(\Delta A/A\), is multiplied by 0.5 because a square root is a power of one half.

So \(\Delta z = 0.49 \times 28.638 = 14.03\), which is rounded to 14, giving z = (29 ± 14).

Using Eq. 3b on that page instead gives z = (29 ± 12). Because the uncertainty begins with a 1, two significant figures are kept and the answer is rounded to match.
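The relative-error rule from the Powers section can also be applied by hand to this example. Each relative error is scaled by the power of its variable (1 for w, 2 for y, 1/2 for A), and the results are added in quadrature. Below is a plain-Python sketch of that arithmetic, using the input values stated above.

```python
import math

# central values and uncertainties from the example
w, dw = 4.52, 0.02   # cm
A, dA = 2.0, 0.2
y, dy = 3.0, 0.6     # cm

z = w * y**2 / math.sqrt(A)
# relative errors scaled by the powers 1 (w), 2 (y) and 1/2 (A)
rel_terms = [1 * dw / w, 2 * dy / y, 0.5 * dA / A]
rel_quad = math.sqrt(sum(t**2 for t in rel_terms))  # added in quadrature

print(z, z * rel_quad)  # about 28.77 and 11.60
```

The result agrees with the JAX value of about 11.6 computed in the next cell.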

def q(x):
    return x[0]*x[2]*x[2]/np.sqrt(x[1])
x_ = np.array([4.52, 2., 3.]) #w,A,y
dx_ = np.array([.02, .2, .6])

print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
q =  28.765104 +/- 11.596283

Check with a plot

import numpy as onp  #using jax as np right now
w_ = onp.random.normal(x_[0], dx_[0], 10000)
A_ = onp.random.normal(x_[1], dx_[1], 10000)
y_ = onp.random.normal(x_[2], dx_[2], 10000)
x__ = np.vstack((w_, A_, y_))
z_ = q(x__)
print('mean =', np.mean(z_), 'std = ', np.std(z_))
mean = 30.050316 std =  11.813263
import matplotlib.pyplot as plt
_ = plt.hist(z_, bins=50)
[Figure: histogram of the 10,000 sampled values of z, 50 bins]
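The sampled standard deviation (about 11.8) is close to the linearized 11.6. The sampled mean (about 30.05), however, sits above q evaluated at the central values (28.77), because z is curved in y and A. The sketch below goes one order beyond the notebook's linear formula, so treat it as an addition rather than part of the original method. It uses the second-order estimate \(E[q] \approx q + \tfrac{1}{2}\sum_i (\partial^2 q/\partial x_i^2)\,\sigma_i^2\), valid for uncorrelated inputs, with the diagonal of jax.hessian. It gives about 30.02, close to the sampled mean.

```python
import jax.numpy as np
from jax import hessian


def q(x):
    # z = w * y**2 / sqrt(A), with x = [w, A, y] as in the example
    return x[0] * x[2] * x[2] / np.sqrt(x[1])


x_ = np.array([4.52, 2., 3.])
dx_ = np.array([.02, .2, .6])

# second-order mean estimate for uncorrelated inputs:
# E[q] ~ q(x) + 1/2 * sum_i (d^2 q / dx_i^2) * sigma_i^2
H = hessian(q)(x_)
mean_2nd = q(x_) + 0.5 * np.sum(np.diag(H) * dx_**2)
print(q(x_), mean_2nd)
```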

Example 2

Also taken from http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html

w = (4.52 ± 0.02) cm, x = (2.0 ± 0.2) cm, y = (3.0 ± 0.6) cm.

Find

\[\begin{equation} z = w x +y^2 \end{equation}\]

We have v = wx = (9.0 ± 0.9) cm².
The uncertainty in v is calculated with the relative-error rule for products, as in the previous example. Then, from Eq. 1b on that page, Δz = 3.7, so
z = (18 ± 4) cm².
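The two-step calculation described above can be written out in plain Python. It applies the product rule for v = wx, then the power rule for y², then the sum rule for z = v + y². This is a sketch using the stated inputs, and it should agree with the unrounded JAX result in the next cell (about 3.71).

```python
import math

w, dw = 4.52, 0.02   # cm
x, dx = 2.0, 0.2     # cm
y, dy = 3.0, 0.6     # cm

# Step 1: v = w*x, relative errors added in quadrature (product rule)
v = w * x
dv = v * math.sqrt((dw / w) ** 2 + (dx / x) ** 2)

# Step 2: y**2 via the power rule, then z = v + y**2 via the sum rule
y2 = y**2
dy2 = y2 * 2 * dy / y
z = v + y2
dz = math.sqrt(dv**2 + dy2**2)

print(v, dv)  # about 9.04 and 0.905
print(z, dz)  # about 18.04 and 3.71
```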

def q(x):
    # [w,x,y]
    return x[0]*x[1]+x[2]*x[2]
x_ = np.array([4.52, 2., 3.]) #w,x,y
dx_ = np.array([.02, .2, .6])
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
18.04 +/- 3.711983

An example with many inputs

The code we used for error_prop_jax_gen is generic and supports functions q of any number of variables.

def q(x):
    return np.sum(x)
x_ = 1.*np.arange(1,101) #counts from 1-100 (and 1.* to make them floats)
dx_ = 0.1*np.ones(100)

The sum from 1 to \(N\) is \(N(N+1)/2\) (see the story of Gauss), so we expect q(x) = 5050. Every partial derivative \(\partial q/\partial x_i\) equals 1, so the uncertainty should be \(\sqrt{100} \times 0.1 = 1\).

print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
5050.0 +/- 1.0

Another toy example: the product from 1 to 10.

def q(x):
    return np.prod(x)
x_ = 1.*np.arange(1,11) #counts from 1-10 (and 1.* to make them floats)
dx_ = 0.1*np.ones(10)
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
3628800.0 +/- 451748.1

Checking this is an exercise left to the reader :-)
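One way to carry out the check: for a product, \(\partial q/\partial x_i = q/x_i\). The general formula then becomes \(\sigma_q = q\sqrt{\sum_i (\sigma_i/x_i)^2}\), which is easy to evaluate directly. Below is a plain-Python sketch (math.prod needs Python 3.8 or later).

```python
import math

# q = x_1 * x_2 * ... * x_10 with x_i = i and sigma_i = 0.1 for every i
xs = [float(i) for i in range(1, 11)]
q_val = math.prod(xs)
# dq/dx_i = q / x_i, so each term (dq/dx_i * sigma_i)**2 = (q * sigma_i / x_i)**2
sigma_q = q_val * math.sqrt(sum((0.1 / xi) ** 2 for xi in xs))

print(q_val, sigma_q)  # about 3628800.0 and 451748
```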