# Revisiting error propagation with automatic differentiation

By Kyle Cranmer, March 2, 2020

This notebook is dedicated to Freeman Dyson, who died on February 28, 2020, in Princeton, NJ, at the age of 96.

“New directions in science are launched by new tools much more often than by new concepts. The effect of a concept-driven revolution is to explain old things in new ways. The effect of a tool-driven revolution is to discover new things that have to be explained.”

– Freeman Dyson

## Reminder of propagation of errors

This notebook was made to investigate the propagation of errors formula. We imagine that we have a function $q(x,y)$ and we want to propagate the uncertainty on $x$ and $y$ (denoted $\sigma_x$ and $\sigma_y$, respectively) through to the quantity $q$.

The most straightforward way to do this is to randomly sample $x$ and $y$, evaluate $q$, and look at its distribution. This is really the definition of what we mean by propagation of uncertainty. It’s very easy to do with some simple Python code.
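As a minimal sketch of that direct approach, the snippet below uses only the Python standard library. It assumes Gaussian, independent uncertainties on $x$ and $y$; the function name `propagate_mc`, the sample size and the seed are choices made for this illustration, not part of the original notebook.

```python
import random
import statistics


def propagate_mc(q, x, sigma_x, y, sigma_y, n=100_000, seed=0):
    """Sample x and y from independent Gaussians and return the mean and std of q(x, y)."""
    rng = random.Random(seed)  # fixed seed so the result is reproducible
    samples = [q(rng.gauss(x, sigma_x), rng.gauss(y, sigma_y)) for _ in range(n)]
    return statistics.fmean(samples), statistics.stdev(samples)


# q(x, y) = x * y with x = 2.0 +/- 0.1 and y = 3.0 +/- 0.1
mean, std = propagate_mc(lambda x, y: x * y, 2.0, 0.1, 3.0, 0.1)
print(f"q = {mean:.3f} +/- {std:.3f}")
```

For these inputs the sample standard deviation should come out close to 0.36, the value the calculus formula gives for the same product in the multiplication example later on.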

The calculus formula for the propagation of errors is really an approximation: it linearizes $q$ around the central values of $x$ and $y$, and it assumes that $x$ and $y$ are uncorrelated. For a general $q(x,y)$ it reads

$$\sigma_q^2 = \left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2$$
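The reasoning behind this formula can be sketched with a first-order Taylor expansion. Writing $(\bar x, \bar y)$ for the central values, with all partial derivatives evaluated there:

$$q(x,y) \approx q(\bar x,\bar y) + \frac{\partial q}{\partial x}(x-\bar x) + \frac{\partial q}{\partial y}(y-\bar y)$$

Taking the variance of both sides gives

$$\sigma_q^2 \approx \left(\frac{\partial q}{\partial x}\right)^2\sigma_x^2 + \left(\frac{\partial q}{\partial y}\right)^2\sigma_y^2 + 2\,\frac{\partial q}{\partial x}\frac{\partial q}{\partial y}\,\mathrm{cov}(x,y)$$

For uncorrelated $x$ and $y$ the covariance term vanishes, which leaves the formula above. The neglected second- and higher-order terms are why the result is exact only when $q$ is linear in $x$ and $y$.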

In the special case of addition, $q(x,y) = x\pm y$, we have $\sigma_q^2 = \sigma_x^2 + \sigma_y^2$.

In the special cases of multiplication, $q(x,y) = x y$, and division, $q(x,y) = x / y$, we have $(\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2$, which we can rewrite as $\sigma_q = |q| \sqrt{(\sigma_x/x)^2 + (\sigma_y/y)^2}$, with $q = xy$ or $q = x/y$ as appropriate.
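These special cases follow directly from the general formula once the partial derivatives are written out. For addition both partial derivatives are $\pm 1$, which gives $\sigma_q^2 = \sigma_x^2 + \sigma_y^2$ immediately. For multiplication, $\partial q/\partial x = y$ and $\partial q/\partial y = x$, so

$$\sigma_q^2 = y^2\sigma_x^2 + x^2\sigma_y^2 = (xy)^2\left[\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2\right]$$

For division, $\partial q/\partial x = 1/y$ and $\partial q/\partial y = -x/y^2$, so

$$\sigma_q^2 = \frac{\sigma_x^2}{y^2} + \frac{x^2\sigma_y^2}{y^4} = \left(\frac{x}{y}\right)^2\left[\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2\right]$$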

Let’s try out these formulas and compare the direct approach of sampling the distribution with the predictions from these formulas.

## Automatic Differentiation

Excerpts from the Wikipedia article: https://en.wikipedia.org/wiki/Automatic_differentiation

In mathematics and computer algebra, automatic differentiation (AD), also called algorithmic differentiation or computational differentiation,[1][2] is a set of techniques to numerically evaluate the derivative of a function specified by a computer program. AD exploits the fact that every computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (addition, subtraction, multiplication, division, etc.) and elementary functions (exp, log, sin, cos, etc.). By applying the chain rule repeatedly to these operations, derivatives of arbitrary order can be computed automatically, accurately to working precision, and using at most a small constant factor more arithmetic operations than the original program.

Usually, two distinct modes of AD are presented, forward accumulation (or forward mode) and reverse accumulation (or reverse mode). For a composite function $y = f(g(h(x)))$, write $w_0 = x$, $w_1 = h(w_0)$, $w_2 = g(w_1)$ and $w_3 = f(w_2) = y$, so that the chain rule gives $\frac{dy}{dx} = \frac{dy}{dw_2}\frac{dw_2}{dw_1}\frac{dw_1}{dx}$. Forward accumulation specifies that one traverses the chain rule from inside to outside (that is, first compute $dw_1/dx$, then $dw_2/dw_1$, and at last $dy/dw_2$), while reverse accumulation has the traversal from outside to inside (first compute $dy/dw_2$, then $dw_2/dw_1$, and at last $dw_1/dx$). More succinctly,

• forward accumulation computes the recursive relation $\frac{dw_i}{dx} = \frac{dw_i}{dw_{i-1}}\frac{dw_{i-1}}{dx}$ with $w_3 = y$, and,

• reverse accumulation computes the recursive relation $\frac{dy}{dw_i} = \frac{dy}{dw_{i+1}}\frac{dw_{i+1}}{dw_i}$ with $w_0 = x$.
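To make forward accumulation concrete, here is a minimal pure-Python sketch based on dual numbers, a standard way to explain forward mode: every value carries its derivative along with it, and each elementary operation applies its own derivative rule, so the chain rule is traversed from the inputs outward. The class `Dual` and the functions `dsqrt`, `partials_forward` and `error_prop_forward` are hypothetical names made up for this illustration. The sketch supports only `+`, `-`, `*`, `/` and a square root, and it is not meant to describe how JAX's `jacfwd` is implemented.

```python
import math


class Dual:
    """A number a + b*eps with eps**2 = 0: `val` carries a, `der` carries the derivative b."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    @staticmethod
    def lift(other):
        # treat plain numbers as constants (zero derivative)
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        # product rule
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = Dual.lift(other)
        # quotient rule
        return Dual(self.val / other.val,
                    (self.der * other.val - self.val * other.der) / other.val ** 2)

    def __rtruediv__(self, other):
        return Dual.lift(other) / self


def dsqrt(u):
    # chain rule: d sqrt(u) = du / (2 sqrt(u))
    r = math.sqrt(u.val)
    return Dual(r, u.der / (2.0 * r))


def partials_forward(q, x):
    """One forward pass per input, seeding that input's derivative with 1."""
    grads = []
    for i in range(len(x)):
        args = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(x)]
        grads.append(q(args).der)
    return grads


def error_prop_forward(q, x, dx):
    # weight each partial derivative by its uncertainty and add in quadrature
    g = partials_forward(q, x)
    return math.sqrt(sum((gi * di) ** 2 for gi, di in zip(g, dx)))


# z = w * y**2 / sqrt(A), with inputs ordered [w, A, y]
def z(x):
    return x[0] * x[2] * x[2] / dsqrt(x[1])


print(partials_forward(lambda x: x[0] * x[1], [2.0, 3.0]))  # [3.0, 2.0]
print(error_prop_forward(z, [4.52, 2.0, 3.0], [0.02, 0.2, 0.6]))
```

Forward mode needs one pass per input to build the full gradient, whereas reverse mode obtains the whole gradient of a scalar output from a single backward pass; which one is cheaper depends on how many inputs and outputs the function has.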

```python
from jax import grad, jacfwd
import jax.numpy as np
```


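Now here are three lines of code for the propagation of uncertainty formula:

$$\sigma_q = \sqrt{\left( \frac{\partial q}{\partial x} \sigma_x \right)^2 + \left( \frac{\partial q}{\partial y}\sigma_y \right)^2}$$

```python
def error_prop_jax_gen(q, x, dx):
    # jacfwd(q) builds a function returning the partial derivatives dq/dx_i (forward mode)
    jac = jacfwd(q)
    # weight each partial derivative by its input's uncertainty, then add in quadrature
    return np.sqrt(np.sum(np.power(jac(x)*dx, 2)))
```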
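As a cross-check that does not depend on JAX, the same formula can be sketched with central finite differences swapped in for automatic differentiation. This is a different technique: unlike AD, its derivatives carry truncation and round-off error that depend on the step size. The function name `error_prop_fd`, the helper `q_div` and the default step `h=1e-6` are assumptions made for this sketch.

```python
import math


def error_prop_fd(q, x, dx, h=1e-6):
    """First-order error propagation with partial derivatives from central differences."""
    total = 0.0
    for i in range(len(x)):
        step = h * max(1.0, abs(x[i]))  # scale the step with the size of x[i]
        up = list(x)
        up[i] += step
        down = list(x)
        down[i] -= step
        dq_dxi = (q(up) - q(down)) / (2 * step)  # central-difference estimate of dq/dx_i
        total += (dq_dxi * dx[i]) ** 2
    return math.sqrt(total)


def q_div(x):
    return x[0] / x[1]


print(error_prop_fd(q_div, [2.0, 3.0], [0.1, 0.1]))
```

For the division example used later in the notebook ($x = 2.0 \pm 0.1$, $y = 3.0 \pm 0.1$) this should agree with the `error_prop_jax_gen` result to several digits.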


## Setup two observations with uncertainties

Below I’ll use $x$ and $y$ for symbols, but they will be stored in the array `x`, so that `x[0]` is $x$ and `x[1]` is $y$.

```python
x_ = np.array([2.,3.])
dx_ = np.array([.1,.1])
```



In the special case of addition, $q(x,y) = x\pm y$, we have $\sigma_q^2 = \sigma_x^2 + \sigma_y^2$.

```python
def q(x):
    return x[0]+x[1]

def error_prop_classic(x, dx):
    # for q = x[0] + x[1]
    ret = dx[0]**2 + dx[1]**2
    return np.sqrt(ret)
```

```python
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
```

q =  5.0 +/- 0.14142136

```python
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
```

q =  5.0 +/- 0.14142136


## Multiplication and Division

In the special case of multiplication

$$q(x,y) = x y$$

and division

$$q(x,y) = \frac{x}{y}$$

we have

$$(\sigma_q/q)^2 = (\sigma_x/x)^2 + (\sigma_y/y)^2$$

which we can rewrite as

$$\sigma_q = |q| \sqrt{\left(\frac{\sigma_x}{x}\right)^2 + \left(\frac{\sigma_y}{y}\right)^2}$$
```python
def q(x):
    return x[0]*x[1]

def error_prop_classic(x, dx):
    # for q = x[0]*x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2
    return (x[0]*x[1])*np.sqrt(ret)
```

```python
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
```

q =  6.0 +/- 0.36055514

```python
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
```

q =  6.0 +/- 0.36055514

```python
def q(x):
    return x[0]/x[1]

def error_prop_classic(x, dx):
    # for q = x[0]/x[1]
    ret = (dx[0]/x[0])**2 + (dx[1]/x[1])**2
    return (x[0]/x[1])*np.sqrt(ret)
```

```python
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
```

q =  0.6666667 +/- 0.040061682

```python
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
```

q =  0.6666667 +/- 0.040061682


## Powers

For $q(x,y) = x^m y^n$ we have

$$(\sigma_q/q)^2 = \left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2$$

which we can rewrite as

$$\sigma_q = |x^m y^n| \sqrt{\left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2}$$
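The power rule follows from the general formula in the same way. With $q = x^m y^n$,

$$\frac{\partial q}{\partial x} = m x^{m-1}y^n = \frac{m\,q}{x}, \qquad \frac{\partial q}{\partial y} = n x^m y^{n-1} = \frac{n\,q}{y}$$

so

$$\sigma_q^2 = \left(\frac{m\,q}{x}\sigma_x\right)^2 + \left(\frac{n\,q}{y}\sigma_y\right)^2 = q^2\left[\left(|m|\frac{\sigma_x}{x}\right)^2 + \left(|n|\frac{\sigma_y}{y}\right)^2\right]$$

where squaring is what allows $m$ and $n$ to be written as $|m|$ and $|n|$. With $m=2$, $n=3$, $x = 1.5 \pm 0.1$ and $y = 2.5 \pm 0.1$, this gives $\sigma_q = 35.15625\sqrt{(0.2/1.5)^2 + (0.3/2.5)^2} \approx 6.306$.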
```python
def q(x, m=2, n=3):
    return np.power(x[0],m)*np.power(x[1],n)
```

```python
x_ = np.array([1.5, 2.5])
dx_ = np.array([.1, .1])
q(x_)
```

DeviceArray(35.15625, dtype=float32)

```python
def error_prop_classic(x, dx):
    # for q = x[0]**2 * x[1]**3, i.e. m=2, n=3
    dq_ = q(x)*np.sqrt(np.power(2*dx[0]/x[0],2)+np.power(3*dx[1]/x[1],2))
    return dq_
```

```python
print('q = ', q(x_), '+/-', error_prop_classic(x_, dx_))
```

q =  35.15625 +/- 6.3063865

```python
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
```

q =  35.15625 +/- 6.3063865


## Misc Examples

See some examples here:

http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html

Example: w = (4.52 ± 0.02) cm, A = (2.0 ± 0.2), y = (3.0 ± 0.6) cm. Find

$$z=\frac{wy^2}{\sqrt{A}}$$

The second relative error, (Δy/y), is multiplied by 2 because the power of y is 2.
The third relative error, (ΔA/A), is multiplied by 0.5 since a square root is a power of one half.

So Δz = 0.49 × 28.638 = 14.03, which we round to 14.

z = (29 ± 14). Using Eq. 3b, z = (29 ± 12). Because the uncertainty begins with a 1, we keep two significant figures and round the answer to match.
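The quadrature version of this recipe (each relative error weighted by the power of its variable, then added in quadrature) can be sketched in a few lines of plain Python. This assumes Eq. 3b on the linked page is that quadrature sum; the assumption is consistent with it reproducing the ±12 quoted above. The variable names are made up for this illustration.

```python
import math

# inputs from the example: w and y in cm; no unit is given for A
w, dw = 4.52, 0.02
A, dA = 2.0, 0.2
y, dy = 3.0, 0.6

z = w * y**2 / math.sqrt(A)
# weight each relative error by the power of its variable: w**1, y**2, A**(-1/2)
rel_terms = [1 * dw / w, 2 * dy / y, 0.5 * dA / A]
dz = abs(z) * math.sqrt(sum(t**2 for t in rel_terms))  # add in quadrature
print(f"z = {z:.2f} +/- {dz:.2f}")
```

The result, roughly 28.77 ± 11.60, can be compared with the `error_prop_jax_gen` output for the same function.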

```python
def q(x):
    return x[0]*x[2]*x[2]/np.sqrt(x[1])
```

```python
x_ = np.array([4.52, 2., 3.]) #w,A,y
dx_ = np.array([.02, .2, .6])
```

```python
print('q = ', q(x_), '+/-', error_prop_jax_gen(q, x_, dx_))
```

q =  28.765104 +/- 11.596283


### Check with a plot

```python
import numpy as onp  # plain NumPy for sampling; np is jax.numpy in this notebook
w_ = onp.random.normal(x_[0], dx_[0], 10000)
A_ = onp.random.normal(x_[1], dx_[1], 10000)
y_ = onp.random.normal(x_[2], dx_[2], 10000)
x__ = np.vstack((w_, A_, y_))  # shape (3, 10000): rows are w, A, y
z_ = q(x__)
print('mean =', np.mean(z_), 'std = ', np.std(z_))
```

mean = 30.050316 std =  11.813263

```python
import matplotlib.pyplot as plt

_ = plt.hist(z_, bins=50)
```


## Example 2

Also taken from http://www.geol.lsu.edu/jlorenzo/geophysics/uncertainties/Uncertaintiespart2.html

w = (4.52 ± 0.02) cm, x = (2.0 ± 0.2) cm, y = (3.0 ± 0.6) cm. Find

$$z = w x +y^2$$

We have v = wx = (9.0 ± 0.9) cm².
The uncertainty in v follows from the product rule for relative errors. Then from Eq. 1b, Δz = 3.7, so
z = (18 ± 4) cm².
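The two-step recipe in the quote can be sketched in plain Python (standard library only; the variable names are made up for this illustration). Treating v and y² as independent in the last step is justified here because they depend on disjoint inputs.

```python
import math

w, dw = 4.52, 0.02  # cm
x, dx = 2.0, 0.2    # cm
y, dy = 3.0, 0.6    # cm

# step 1: v = w*x is a product, so relative errors add in quadrature
v = w * x
dv = v * math.sqrt((dw / w) ** 2 + (dx / x) ** 2)

# step 2: y**2 is a power, so its relative error is 2*dy/y
y2 = y ** 2
dy2 = y2 * 2 * dy / y

# step 3: z = v + y**2 is a sum, so absolute errors add in quadrature
z = v + y2
dz = math.sqrt(dv ** 2 + dy2 ** 2)
print(f"v = {v:.2f} +/- {dv:.2f} cm^2, z = {z:.2f} +/- {dz:.2f} cm^2")
```

The resulting Δz of about 3.71 can be compared with the `error_prop_jax_gen` output for the same function.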

```python
def q(x):
    # [w,x,y]
    return x[0]*x[1]+x[2]*x[2]
```

```python
x_ = np.array([4.52, 2., 3.]) #w,x,y
dx_ = np.array([.02, .2, .6])
```

```python
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
```

18.04 +/- 3.711983


## An example with many inputs

The code we used for `error_prop_jax_gen` is generic and supports functions `q` of any number of variables.

```python
def q(x):
    return np.sum(x)
```

```python
x_ = 1.*np.arange(1,101) #counts from 1-100 (and 1.* to make them floats)
dx_ = 0.1*np.ones(100)
```


The sum from $1$ to $N$ is $N(N+1)/2$ (see the story of Gauss), so we expect $q(x) = 5050$. Since $\partial q/\partial x_i = 1$ for every term, the uncertainty should be $\sqrt{100} \times 0.1 = 1$.

```python
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
```

5050.0 +/- 1.0


Another toy example: the product from 1 to 10.

```python
def q(x):
    return np.prod(x)
```

```python
x_ = 1.*np.arange(1,11) # counts from 1-10 (and 1.* to make them floats)
dx_ = 0.1*np.ones(10)
print(q(x_),'+/-', error_prop_jax_gen(q, x_, dx_))
```

3628800.0 +/- 451748.1

Checking this is an exercise left to the reader :-)
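One way to do the check: for a product $q = \prod_i x_i$, each partial derivative is $\partial q/\partial x_i = q/x_i$, so the general formula reduces to relative errors added in quadrature, $\sigma_q = |q|\sqrt{\sum_i (\sigma_i/x_i)^2}$. A minimal sketch using only the Python standard library (`math.prod` needs Python 3.8 or later; the variable names are made up for this illustration):

```python
import math

x = [float(i) for i in range(1, 11)]  # 1.0, 2.0, ..., 10.0
dx = [0.1] * len(x)

q = math.prod(x)  # 10! = 3628800
# for a product, dq/dx_i = q / x_i, so the relative errors add in quadrature
dq = abs(q) * math.sqrt(sum((d / v) ** 2 for v, d in zip(x, dx)))
print(q, "+/-", dq)
```

In double precision this gives about 451748.1, matching the JAX output above to the digits shown.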