Visualizing joint and marginal distributions

Let’s explore a dataset that has four continuous random variables and a discrete species label (plus two other categorical columns, island and sex, that we will not use here). We will visualize various 1-D and 2-D marginal distributions. The 2-D distributions are joint distributions in the context of the two random variables being plotted, but are marginal distributions in the context of the full dataset.
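
To make the terminology concrete, here is one way to write it down. The notation is introduced only for illustration: \(X_1,\dots,X_4\) stand for the four continuous variables and \(S\) for the discrete species label. A 2-D marginal for \(X_i\) and \(X_j\) integrates out the other two continuous variables, \(X_k\) and \(X_l\), and sums over the discrete one:

\[
p(x_i, x_j) = \sum_{s} \int\!\!\int p(x_1, x_2, x_3, x_4, s)\, dx_k\, dx_l
\]

A 1-D marginal integrates out one more variable: \(p(x_i) = \int p(x_i, x_j)\, dx_j\).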

import seaborn as sns
sns.__version__ # check version, need 0.11.0 or greater
'0.11.0'
penguins = sns.load_dataset("penguins")
penguins
     species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex
0     Adelie  Torgersen            39.1           18.7              181.0       3750.0    Male
1     Adelie  Torgersen            39.5           17.4              186.0       3800.0  Female
2     Adelie  Torgersen            40.3           18.0              195.0       3250.0  Female
3     Adelie  Torgersen             NaN            NaN                NaN          NaN     NaN
4     Adelie  Torgersen            36.7           19.3              193.0       3450.0  Female
...      ...        ...             ...            ...                ...          ...     ...
339   Gentoo     Biscoe             NaN            NaN                NaN          NaN     NaN
340   Gentoo     Biscoe            46.8           14.3              215.0       4850.0  Female
341   Gentoo     Biscoe            50.4           15.7              222.0       5750.0    Male
342   Gentoo     Biscoe            45.2           14.8              212.0       5200.0  Female
343   Gentoo     Biscoe            49.9           16.1              213.0       5400.0    Male

344 rows × 7 columns

First we can visualize the raw data both as a scatter plot of xy-pairs and as a “rug plot” showing ticks for x and y individually.

# Use JointGrid directly to draw a custom plot
g = sns.JointGrid(data=penguins, x="bill_length_mm", y="bill_depth_mm", space=0, ratio=15)
g.plot_joint(sns.scatterplot, alpha=.6, legend=False)
g.plot_marginals(sns.rugplot, height=1, alpha=.6)
<seaborn.axisgrid.JointGrid at 0x7fe81081f6a0>
../_images/visualize_marginals_4_1.png

We can make a similar kind of plot where, instead of visualizing the raw data, we use a histogram to approximate the parent distribution for both the joint and the marginals.

sns.jointplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", kind="hist")
<seaborn.axisgrid.JointGrid at 0x7fe8320b53a0>
../_images/visualize_marginals_6_1.png

Note that the marginal distribution displayed on the top is the same as a simple histogram of bill_length_mm. When we histogram a variable \(X_i\), we are implicitly marginalizing over all the other variables in the dataset.

sns.displot(data=penguins, x="bill_length_mm", kind="hist")
<seaborn.axisgrid.FacetGrid at 0x7fe810874be0>
../_images/visualize_marginals_8_1.png
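
The statement that a 1-D histogram implicitly marginalizes can be checked numerically. The sketch below is an illustration under stated assumptions, not part of the original notebook: the helper name joint_and_marginal_counts is hypothetical, and it uses synthetic stand-in data with NumPy rather than the penguin measurements. It bins the same sample in 2-D and in 1-D on shared x bin edges, then sums the 2-D counts over the y axis.

```python
import numpy as np


def joint_and_marginal_counts(x, y, x_edges, n_y_bins=8):
    """Return (2-D counts, 1-D counts of x) computed on shared x bin edges."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Drop pairs with a missing value, since the penguins table has NaN rows.
    keep = ~(np.isnan(x) | np.isnan(y))
    x, y = x[keep], y[keep]
    joint, _, _ = np.histogram2d(x, y, bins=[x_edges, n_y_bins])
    marginal, _ = np.histogram(x, bins=x_edges)
    return joint, marginal


# Synthetic stand-in data (NOT the penguin measurements).
rng = np.random.default_rng(0)
x = rng.normal(44.0, 5.0, size=300)
y = rng.normal(17.0, 2.0, size=300)
x[3] = np.nan  # mimic a row with missing measurements
y[3] = np.nan

x_edges = np.linspace(np.nanmin(x), np.nanmax(x), 11)
joint, marginal = joint_and_marginal_counts(x, y, x_edges)

# Summing the joint counts over the y axis recovers the x histogram.
print(np.array_equal(joint.sum(axis=1), marginal))
```

The same check should carry over to the bill_length_mm and bill_depth_mm columns, provided both histograms use identical bin edges for x.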

We can also use kernel density estimation (KDE) to get a smoother estimate than a histogram.

sns.jointplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", kind="kde", fill=True)
<seaborn.axisgrid.JointGrid at 0x7fe83243f460>
../_images/visualize_marginals_10_1.png
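
To give a feel for what a kernel density estimate computes, here is a minimal 1-D sketch. It is not seaborn’s implementation: the function name gaussian_kde_1d, the Gaussian kernel, and the Scott’s-rule bandwidth are assumptions of this example, and the data are synthetic. The estimate places one Gaussian bump on every sample and averages them.

```python
import numpy as np


def gaussian_kde_1d(samples, grid, bandwidth=None):
    """Evaluate a Gaussian kernel density estimate of `samples` on `grid`."""
    samples = np.asarray(samples, dtype=float)
    samples = samples[~np.isnan(samples)]  # ignore missing values
    grid = np.asarray(grid, dtype=float)
    n = samples.size
    if bandwidth is None:
        # Scott's rule of thumb for 1-D data (an assumption of this sketch).
        bandwidth = samples.std(ddof=1) * n ** (-1 / 5)
    # One Gaussian bump per sample, averaged over all samples.
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    return kernels.sum(axis=1) / (n * bandwidth)


# Synthetic two-mode stand-in data (NOT the penguin measurements).
rng = np.random.default_rng(1)
samples = np.concatenate([rng.normal(39, 2.5, 150), rng.normal(48, 3, 190)])

grid = np.linspace(samples.min() - 15, samples.max() + 15, 600)
density = gaussian_kde_1d(samples, grid)

# A density estimate should integrate to roughly 1 over a wide enough grid.
area = float(density.sum() * (grid[1] - grid[0]))
print(round(area, 3))
```

A smaller bandwidth follows the samples more closely and a larger one smooths more, which is the main tuning choice in any estimate of this kind.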

Grouping by class

We can also look at bill_length_mm vs. bill_depth_mm grouped by species. Before we plot, let’s look at the correlation matrix for these classes. Note that the two variables are positively correlated in each case.

penguins[["bill_length_mm","bill_depth_mm","species"]].groupby("species").corr()
                           bill_length_mm  bill_depth_mm
species
Adelie     bill_length_mm        1.000000       0.391492
           bill_depth_mm         0.391492       1.000000
Chinstrap  bill_length_mm        1.000000       0.653536
           bill_depth_mm         0.653536       1.000000
Gentoo     bill_length_mm        1.000000       0.643384
           bill_depth_mm         0.643384       1.000000
sns.jointplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", kind="kde", hue="species")
<seaborn.axisgrid.JointGrid at 0x7fe8325bea60>
../_images/visualize_marginals_13_1.png

The plot above clearly reveals the origin of the multi-modal structure of the underlying data. It does not show that species causes that structure, but the species do map naturally onto the different modes.
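
One way to read this is as marginalizing over a discrete label: \(p(x) = \sum_s p(s)\,p(x \mid s)\), so the pooled distribution is a mixture of the per-group ones. The sketch below illustrates the histogram version of that identity. It uses synthetic groups labelled a, b, and c, which are stand-ins and not the penguin species, and the group means and sizes are made up for the example.

```python
import numpy as np

# Synthetic stand-in data (NOT the penguin measurements): three labelled
# groups with different means, so the pooled sample is multi-modal.
rng = np.random.default_rng(2)
labels = np.repeat(["a", "b", "c"], [150, 70, 120])
values = np.concatenate(
    [rng.normal(mean, 2.5, size)
     for mean, size in zip((39.0, 49.0, 47.5), (150, 70, 120))]
)

edges = np.linspace(values.min(), values.max(), 21)

# Histogram of the pooled sample: the label is marginalized over.
pooled, _ = np.histogram(values, bins=edges)

# One histogram per label, all on the same bin edges.
per_group = {
    str(lab): np.histogram(values[labels == lab], bins=edges)[0]
    for lab in np.unique(labels)
}

# Summing over the discrete label recovers the pooled histogram.
print(np.array_equal(sum(per_group.values()), pooled))
```

Coloring by the label, as hue="species" does, simply shows the terms of this sum separately instead of their total.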

Visualizing multivariate data

Often we are working with higher-dimensional data that is not so easy to visualize. In this penguin dataset, there are four continuous random variables in addition to the categorical species label. We can again print the correlation matrix grouped by species, this time as a 4x4 matrix per species:

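# Select the numeric columns explicitly: recent pandas versions may raise an
# error if corr() encounters string columns such as island or sex.
numeric_cols = ["bill_length_mm", "bill_depth_mm", "flipper_length_mm", "body_mass_g"]
penguins[numeric_cols + ["species"]].groupby("species").corr()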
                              bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g
species
Adelie     bill_length_mm           1.000000       0.391492           0.325785     0.548866
           bill_depth_mm            0.391492       1.000000           0.307620     0.576138
           flipper_length_mm        0.325785       0.307620           1.000000     0.468202
           body_mass_g              0.548866       0.576138           0.468202     1.000000
Chinstrap  bill_length_mm           1.000000       0.653536           0.471607     0.513638
           bill_depth_mm            0.653536       1.000000           0.580143     0.604498
           flipper_length_mm        0.471607       0.580143           1.000000     0.641559
           body_mass_g              0.513638       0.604498           0.641559     1.000000
Gentoo     bill_length_mm           1.000000       0.643384           0.661162     0.669166
           bill_depth_mm            0.643384       1.000000           0.706563     0.719085
           flipper_length_mm        0.661162       0.706563           1.000000     0.702667
           body_mass_g              0.669166       0.719085           0.702667     1.000000
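
For readers who want to see what the grouped correlation amounts to, the following sketch computes one Pearson correlation matrix per label with NumPy alone. The helper grouped_corr is hypothetical and the data are synthetic. Note one assumed simplification: it drops whole rows that contain a NaN, whereas pandas excludes missing values pairwise. The two agree when missing rows are empty across all measured columns.

```python
import numpy as np


def grouped_corr(data, labels):
    """Map each label to the Pearson correlation matrix of its rows.

    `data` has shape (n_samples, n_variables); rows containing NaN are dropped.
    """
    data = np.asarray(data, dtype=float)
    labels = np.asarray(labels)
    complete = ~np.isnan(data).any(axis=1)
    result = {}
    for lab in np.unique(labels):
        rows = data[complete & (labels == lab)]
        # rowvar=False: columns are variables, rows are observations.
        result[str(lab)] = np.corrcoef(rows, rowvar=False)
    return result


# Synthetic stand-in data (NOT the penguin measurements): four columns that
# share a common factor, which makes them positively correlated.
rng = np.random.default_rng(3)
labels = np.repeat(["a", "b"], [100, 80])
shared = rng.normal(size=(180, 1))
data = shared + 0.8 * rng.normal(size=(180, 4))
data[5] = np.nan  # mimic a row with missing measurements

corr = grouped_corr(data, labels)
print(np.round(corr["a"], 2))
```

Each matrix is symmetric with ones on the diagonal, matching the layout of the per-species blocks in the table above.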

We can also make what is called a pairplot, which draws a grid of plots. Each off-diagonal panel shows either a scatter plot of \(X_i\) vs. \(X_j\) or an estimate of \(p(X_i,X_j)\). Note that \(p(X_i,X_j)\) is 2-D, but it implicitly marginalizes over all the other random variables. Along the diagonal, where \(X_i=X_j\), trying to visualize a 2-D distribution doesn’t make much sense, so one usually plots the univariate marginal \(p(X_i)\) instead.

sns.pairplot(penguins, kind="kde")
<seaborn.axisgrid.PairGrid at 0x7fe8603e2490>
../_images/visualize_marginals_18_1.png
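
The layout of such a grid can be sketched without any plotting. The example below is an illustration with a hypothetical helper, pair_grid_histograms, and synthetic data. It uses histograms in place of kernel density estimates to keep the code short: panel (i, i) holds the 1-D histogram of column i, and panel (i, j) holds the 2-D histogram of columns i and j.

```python
import numpy as np


def pair_grid_histograms(data, bins=10):
    """Build the panels of a pair grid as arrays of histogram counts."""
    data = np.asarray(data, dtype=float)
    data = data[~np.isnan(data).any(axis=1)]  # drop incomplete rows
    n_vars = data.shape[1]
    # One set of bin edges per variable, shared by every panel that uses it.
    edges = [np.linspace(col.min(), col.max(), bins + 1) for col in data.T]
    panels = {}
    for i in range(n_vars):
        for j in range(n_vars):
            if i == j:
                # Diagonal: univariate marginal of variable i.
                panels[i, j] = np.histogram(data[:, i], bins=edges[i])[0]
            else:
                # Off-diagonal: 2-D marginal of variables i and j.
                panels[i, j] = np.histogram2d(
                    data[:, i], data[:, j], bins=[edges[i], edges[j]]
                )[0]
    return panels


# Synthetic stand-in data (NOT the penguin measurements): 200 rows, 4 columns.
rng = np.random.default_rng(4)
data = rng.normal(size=(200, 4))

panels = pair_grid_histograms(data)
print(len(panels), panels[0, 0].shape, panels[0, 1].shape)

# Each off-diagonal panel sums to the diagonal panel of its first variable.
print(np.array_equal(panels[0, 1].sum(axis=1), panels[0, 0]))
```

The last line repeats the marginalization idea from earlier: every panel in the grid is a marginal of the same underlying joint distribution.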

We can also make a pairplot grouped by species, which reveals a lot about this dataset.

sns.pairplot(penguins, hue="species");
../_images/visualize_marginals_20_0.png

Note that there were problems running these examples with seaborn 0.10.1, so check that you have version 0.11.0 or greater.