import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
Lecture 5 - Visualization
Overview
We provide general principles for visualization and cover the basics of plotting in Python with Seaborn and Matplotlib.
References
This lecture contains material from:
- Seaborn documentation
- Chapter 4, Data Science: A First Introduction with Python (Timbers et al. 2022)
- Chapter 9, Python for Data Analysis, 3E (Wes McKinney, 2022)
- STAT545, UBC
- Leek, J. and Peng, R. What is the question? Science, 347(6228):1314–1315, 2015.
- Gelman, A., Pasarica, C. and Dodhia R. 2002. “Let’s Practice What We Preach: Turning Tables into Graphs.” The American Statistician 56 (2).
Introduction
Asking questions about data
A visualization helps answer a question about a dataset.
A good visualization will clearly answer your question without distraction; a great visualization can illustrate what the question was without any text explanation.
Types of questions (Leek and Peng, 2015)
From Data Science: A First Introduction with Python:
- Descriptive: a question that asks about summaries of a dataset without interpretation
- Exploratory: a question that asks if there are patterns or trends in a single dataset (often used to develop hypotheses)
- Predictive: a question that asks about predicting measurements for individuals
- Inferential: a question that looks for patterns or trends and tests for applicability in a wider population
- Causal: a question that asks whether changing one variable will lead to a change in another variable, on average in a population
- Mechanistic: a question that asks about the underlying mechanism
Generally we use visualization to help answer descriptive and exploratory questions.
Examples of exploratory questions:
- What is the data quality? Are there inconsistencies or illogical values?
- What are the distributions of variables?
- What associations do there seem to be between variables?
From Gelman et al. (2002): visualizations are most helpful to:
- facilitate comparisons
- reveal trends
A famous graph
John Snow’s cholera map (1855):
- Cholera is an intestinal disease that can cause death within hours of the onset of vomiting.
- In 1854, there was a cholera epidemic in Soho, London.
- John Snow, an obstetrician, believed that contaminated water wells were the source of the epidemic.
- This is the famous map he created, which shows cholera cases (black rectangles) clustered around the well at the intersection of Broad and Cambridge Streets.
Source: Wikipedia
Visualization principles
Convey the message
- make sure visualization answers the question as simply and plainly as possible
- use legends and labels so visualization is understandable without reading surrounding text
- ensure data is clearly visible
- use color schemes understandable by those with color blindness
- choose scale appropriately (e.g. log scale)
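As a minimal sketch of the last point, the following compares a linear and a log x-axis on synthetic data spanning several orders of magnitude. The data and variable names are made up for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# hypothetical data: x spans five orders of magnitude
x = 10 ** rng.uniform(0, 5, size=200)
y = np.log10(x) + rng.normal(scale=0.3, size=200)

fig, (ax_lin, ax_log) = plt.subplots(ncols=2, figsize=(7, 3))
ax_lin.scatter(x, y, s=8)
ax_lin.set_title("linear x-axis")   # most points are squeezed against the left edge

ax_log.scatter(x, y, s=8)
ax_log.set_xscale("log")            # spreads the points across orders of magnitude
ax_log.set_title("log x-axis")

for ax in (ax_lin, ax_log):
    ax.set_xlabel("x")
    ax.set_ylabel("y")
fig.tight_layout()
```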
Minimize noise
- avoid chartjunk! e.g. distracting grid lines, gratuitous use of icons, unnecessary 3D
- use colors sparingly - too many colors can be distracting
- if your plot has too many dots or lines and starts to look messy, you need to do something different!
- don’t adjust axes to zoom in on small differences. If the difference is small, show that it is small!
Figure 24.1: From Data Looks Better Naked by Darkhorse Analytics
No no’s
- No pie charts!! (They encode quantitative information in angles and areas, which are hard to judge)
- No stacked bar charts!
Examples of misleading graphs.
Choice of visualization
Summary of basic visualizations (from Data Science: A First Introduction with Python)
- scatter plots visualize the relationship between two quantitative variables
- line plots visualize trends with respect to an independent, ordered quantity (e.g., time)
- bar plots visualize comparisons of amounts
- histograms visualize the distribution of one quantitative variable (i.e., all its possible values and how often they occur)
Import packages
Let’s now import some packages.
We will also use statsmodels (install this package in your msds597 environment).
import statsmodels
Seaborn Relplot
The sns.relplot() function visualizes statistical relationships using either scatter plots or line plots.
The default is a scatter plot. For line plots, use the argument kind='line'.
Let’s first look at the tips data, available in Seaborn.
tips = sns.load_dataset("tips")
tips.head()
 | total_bill | tip | sex | smoker | day | time | size |
---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
3 | 23.68 | 3.31 | Male | No | Sun | Dinner | 2 |
4 | 24.59 | 3.61 | Female | No | Sun | Dinner | 4 |
g = sns.relplot(data=tips,
                x="total_bill",
                y="tip")
g.figure.set_size_inches(5, 3)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue="time",
            aspect=1.5,
            height=3.5)

sns.relplot(
    data=tips,
    x="total_bill",
    y="tip",
    hue="smoker",
    style="time",
    aspect=1.5,
    height=3.5
)

sns.relplot(
    data=tips,
    x="total_bill",
    y="tip",
    hue="size",
    aspect=1.5,
    height=3.5
)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time',
            size="size",
            aspect=1.5,
            height=3.5)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time', # colors
            style='day', # markers
            size="size", # size of the points
            col='sex', # subplots in columns
            kind='scatter', # lineplot or scatterplot
            row='smoker', # subplots in the rows
            height=2.5,
            aspect=1.5)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time',
            style='sex',
            size="size",
            aspect=1.5,
            height=3.5)
dowjones = sns.load_dataset("dowjones")
dowjones.head()
 | Date | Price |
---|---|---|
0 | 1914-12-01 | 55.00 |
1 | 1915-01-01 | 56.55 |
2 | 1915-02-01 | 56.00 |
3 | 1915-03-01 | 58.30 |
4 | 1915-04-01 | 66.45 |
sns.relplot(data=dowjones,
            x="Date",
            y="Price",
            kind="line",
            height=3,
            aspect=3)
fmri = sns.load_dataset("fmri")
fmri.head()
 | subject | timepoint | event | region | signal |
---|---|---|---|---|---|
0 | s13 | 18 | stim | parietal | -0.017552 |
1 | s5 | 14 | stim | parietal | -0.080883 |
2 | s12 | 18 | stim | parietal | -0.081033 |
3 | s11 | 18 | stim | parietal | -0.046134 |
4 | s10 | 18 | stim | parietal | -0.037970 |
sns.relplot(data=fmri,
            x="timepoint",
            y="signal",
            kind="line",
            aspect=2,
            height=3)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    hue="event",
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    hue="region",
    style="event",
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri, kind="line",
    x="timepoint", y="signal", hue="region", style="event",
    dashes=False, markers=True,
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    row="region",
    col="event",
    height=3,
    aspect=1.5
)

sns.relplot(
    data=fmri, kind="line",
    x="timepoint", y="signal", hue="subject",
    col="region", row="event", height=3)
Linear regression
Seaborn also has an lmplot function which plots data and a regression model. Let’s explore this function using the mpg dataset.
This dataset contains a subset of the fuel economy data that the EPA makes available on http://fueleconomy.gov. It contains only models which had a new release every year between 1999 and 2008 - this was used as a proxy for the popularity of the car.
- Format
A data frame with 234 rows and 11 variables:
- manufacturer: manufacturer name
- model: model name
- displ: engine displacement, in litres
- year: year of manufacture
- cyl: number of cylinders
- trans: type of transmission
- drv: the type of drive train, where f = front-wheel drive, r = rear wheel drive, 4 = 4wd
- cty: city miles per gallon
- hwy: highway miles per gallon
- fl: fuel type
- class: “type” of car
mpg = pd.read_csv('../data/mpg.csv')
mpg.head()
 | manufacturer | model | displ | year | cyl | trans | drv | cty | hwy | fl | class |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | audi | a4 | 1.8 | 1999 | 4 | auto(l5) | f | 18 | 29 | p | compact |
1 | audi | a4 | 1.8 | 1999 | 4 | manual(m5) | f | 21 | 29 | p | compact |
2 | audi | a4 | 2.0 | 2008 | 4 | manual(m6) | f | 20 | 31 | p | compact |
3 | audi | a4 | 2.0 | 2008 | 4 | auto(av) | f | 21 | 30 | p | compact |
4 | audi | a4 | 2.8 | 1999 | 6 | auto(l5) | f | 16 | 26 | p | compact |
Some questions we may ask:
- do cars with big engines use more fuel than cars with small engines?
g = sns.relplot(mpg, x='displ', y='hwy', hue='class')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')

g = sns.lmplot(mpg, x='displ', y='hwy')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')

g = sns.lmplot(mpg, x='displ', y='hwy', hue='drv')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')
We can also use an order-2 regression model. This computes the regression coefficients for the model:
y_i = \beta_0 + \beta_1 x_i + \beta_2 x_i^2 + \varepsilon_i
sns.regplot(mpg, x='displ', y='hwy', order=2)
plt.gca().set_xlabel('Engine Displacement (L)')
plt.gca().set_ylabel('Highway miles per gallon')
Text(0, 0.5, 'Highway miles per gallon')
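To make the order-2 model concrete, here is a minimal sketch that estimates \beta_0, \beta_1, \beta_2 by ordinary least squares with NumPy. It uses synthetic data generated from a known quadratic (not the mpg data), and the variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic data from a known quadratic: y = 50 - 10 x + 0.9 x^2 + noise
x = np.linspace(1.5, 7.0, 100)
y = 50.0 - 10.0 * x + 0.9 * x**2 + rng.normal(scale=0.5, size=x.size)

# design matrix with columns 1, x, x^2
X = np.column_stack([np.ones_like(x), x, x**2])

# least-squares estimate of (beta_0, beta_1, beta_2)
beta, *_ = np.linalg.lstsq(X, y, rcond=None)

y_hat = X @ beta  # fitted values along the curve
print(beta)
```

This is the same least-squares fit that np.polyfit(x, y, 2) computes, except that np.polyfit returns the coefficients from the highest power down.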
We can also use locally weighted scatterplot smoothing.
From statsmodels:
Suppose the input data has N points. The algorithm works by estimating y_i by taking the frac*N closest points to (x_i,y_i) based on their x values and estimating y_i using a weighted linear regression. The weight for (x_j,y_j) is tricube function applied to abs(x_i-x_j).
sns.regplot(mpg, x='displ', y='hwy', lowess=True)
plt.gca().set_xlabel('Engine Displacement (L)')
plt.gca().set_ylabel('Highway miles per gallon')
Text(0, 0.5, 'Highway miles per gallon')
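The description of locally weighted smoothing can be sketched in a few lines of NumPy. This is a simplified single pass under stated assumptions: the function name lowess_sketch is hypothetical, distances are scaled by the farthest of the selected neighbours before applying the tricube function, and the robustness iterations that statsmodels can run are omitted.

```python
import numpy as np

def lowess_sketch(x, y, frac=2/3):
    """One pass of locally weighted linear regression (no robustness iterations)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    k = max(2, int(np.ceil(frac * n)))       # number of neighbours used in each local fit
    fitted = np.empty(n)
    for i in range(n):
        dist = np.abs(x - x[i])
        idx = np.argsort(dist)[:k]           # the k points closest to x_i in x
        h = dist[idx].max()                  # distance to the farthest selected neighbour
        u = dist[idx] / h if h > 0 else np.zeros(k)
        w = (1 - u**3) ** 3                  # tricube weights
        # weighted least squares for a local line a + b * x
        A = np.column_stack([np.ones(k), x[idx]])
        sw = np.sqrt(w)
        coef, *_ = np.linalg.lstsq(A * sw[:, None], y[idx] * sw, rcond=None)
        fitted[i] = coef[0] + coef[1] * x[i]
    return fitted

x = np.linspace(0, 10, 50)
y = 3.0 + 2.0 * x                            # exactly linear, so each local fit reproduces it
smooth = lowess_sketch(x, y, frac=0.5)
```

A smaller frac uses fewer neighbours per fit and gives a wigglier curve; a larger frac gives a smoother one.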
Seaborn Displot
Seaborn sns.displot contains functions to visualize the distribution of data.
penguins = sns.load_dataset("penguins")
penguins.head()
 | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex |
---|---|---|---|---|---|---|---|
0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | Male |
1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | Female |
2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | Female |
3 | Adelie | Torgersen | NaN | NaN | NaN | NaN | NaN |
4 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | Female |
Histograms
One of the most common approaches to visualizing a distribution is a histogram.
In a histogram, the x-axis (corresponding to the data variable) is divided into a set of bins and the count of observations falling within each bin is shown using the height of the corresponding bar.
sns.displot(penguins,
            x="flipper_length_mm",
            height=3,
            aspect=1.5)
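To see what the bars encode, the binning can be reproduced by hand. This is a minimal sketch with made-up measurements standing in for a numeric column; the bin width of 5 is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# made-up measurements standing in for a numeric column
values = rng.normal(loc=200, scale=14, size=300)

binwidth = 5
# bin edges covering the whole range of the data
edges = np.arange(values.min(), values.max() + binwidth, binwidth)
counts, edges = np.histogram(values, bins=edges)

# each bar spans [left edge, right edge) and its height is the count in that bin
for left, right, c in zip(edges[:-1], edges[1:], counts):
    print(f"[{left:6.1f}, {right:6.1f}): {c}")
```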
We can choose the bin size:
sns.displot(penguins,
            x="flipper_length_mm",
            binwidth=7.1,
            height=3,
            aspect=1.5)
Or specify the number of bins we want:
sns.displot(penguins,
            x="flipper_length_mm",
            bins=20,
            height=3,
            aspect=1.5
)

sns.displot(penguins, x="flipper_length_mm",
            binwidth=0.3,
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            binwidth=30,
            height=3,
            aspect=1.5) # binwidth too big, the two hills in the data are not visible

sns.displot(penguins,
            x="flipper_length_mm",
            bins=15,
            height=3,
            aspect=1.5)
tips = sns.load_dataset("tips")
tips.head()
 | total_bill | tip | sex | smoker | day | time | size |
---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
3 | 23.68 | 3.31 | Male | No | Sun | Dinner | 2 |
4 | 24.59 | 3.61 | Female | No | Sun | Dinner | 4 |
sns.displot(tips,
            x="size",
            discrete=True,
            height=3,
            aspect=1.5)

sns.displot(tips,
            x="day",
            height=3,
            aspect=1.5)
# no need to specify discrete=True because seaborn figures it out on its own

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            col='island',
            height=4)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            multiple="dodge",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            col="sex",
            height=3,
            aspect=1.5)
Kernel density estimation
Rather than using discrete bins, a kernel density estimate plot smooths the observations with a kernel.
\widehat{f}_h(x) = \frac{1}{nh}\sum_{i=1}^n K\left(\frac{x - x_i}{h}\right).
The default is to use a Gaussian kernel: \widehat{f}_h(x) = \frac{1}{nh\sigma}\frac{1}{\sqrt{2\pi}}\sum_{i=1}^n \exp\left(-\frac{(x - x_i)^2}{2h^2\sigma^2}\right)
where \sigma is the standard deviation of the sample \{x_1,\dots, x_n\}.
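The Gaussian formula above translates directly into NumPy. The sketch below is illustrative only: the function name gaussian_kde_sketch and the synthetic bimodal sample are assumptions, and the sample standard deviation is computed with ddof=1.

```python
import numpy as np

def gaussian_kde_sketch(samples, grid, h):
    """Evaluate the Gaussian kernel density estimate on `grid`, with bandwidth h * sigma."""
    samples = np.asarray(samples, dtype=float)
    grid = np.asarray(grid, dtype=float)
    n = samples.size
    sigma = samples.std(ddof=1)              # sample standard deviation (ddof=1 is an assumption)
    bw = h * sigma                           # effective bandwidth h * sigma
    z = (grid[:, None] - samples[None, :]) / bw
    return np.exp(-0.5 * z**2).sum(axis=1) / (n * bw * np.sqrt(2 * np.pi))

rng = np.random.default_rng(0)
# synthetic bimodal sample
samples = np.concatenate([rng.normal(190, 6, 150), rng.normal(217, 6, 120)])
grid = np.linspace(150, 260, 1101)

rough = gaussian_kde_sketch(samples, grid, h=0.05)   # small bandwidth: wiggly curve
smooth = gaussian_kde_sketch(samples, grid, h=0.3)   # moderate bandwidth
flat = gaussian_kde_sketch(samples, grid, h=2.0)     # large bandwidth: the two modes blur together

dx = grid[1] - grid[0]
print(smooth.sum() * dx)                             # a density should integrate to roughly 1
```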
sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=0.05,
            height=3,
            aspect=1.5) # setting the bandwidth
# overfitting
# curve is jittery and the jitter is from noise, bandwidth is too small

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=0.3,
            height=3,
            aspect=1.5) # setting the bandwidth

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=2,
            height=3,
            aspect=1.5) # setting the bandwidth
# underfitting:
# bandwidth too big, curve too smoothed out, not informative

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            col='island',
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            kind="kde",
            fill=True,
            height=3,
            aspect=1.5)
Bivariate distributions
# bivariate histogram
sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            height=4)

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            cbar=True,
            height=3,
            aspect=1.25) # adding a colorbar

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            hue='species',
            kind='hist',
            height=3) # default is hist

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            kind='kde',
            hue='species',
            height=3)

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            hue="species",
            col='island',
            kind="kde",
            height=3)
Plotting joint and marginal distributions
sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              hue='species',
              height=4)

sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              kind='hist',
              height=4)

sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              hue='species',
              kind='kde',
              height=4)

g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)
type(g)
seaborn.axisgrid.JointGrid
jointplot() is an interface to the JointGrid class, which has helpful functions like plot_joint and plot_marginals.
g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)

g.plot_joint(sns.kdeplot, color="red")

# scatter plot in blue
g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)

# kde plot in red, same plot
g.plot_joint(sns.kdeplot, color="red")

# rug plot in green
g.plot_marginals(sns.rugplot, color="green", height=0.15)
Multiple variables
sns.pairplot(penguins, hue='species',
             height=2.5)
Seaborn Catplot
catplot is similar to relplot, but designed specifically for categorical data.
Generally speaking, there are three different families of categorical plots in Seaborn:
Categorical scatterplots:
- stripplot() (with kind="strip"; the default)
- swarmplot() (with kind="swarm")
Categorical distribution plots:
- boxplot() (with kind="box")
- violinplot() (with kind="violin")
Categorical estimate plots:
- pointplot() (with kind="point")
- barplot() (with kind="bar")
- countplot() (with kind="count")
Categorical scatterplots
sns.catplot(data=tips,
            x="day",
            y="tip",
            # kind='strip' # default is 'strip'
            jitter=False, # default is True
            height=4.5)

sns.catplot(data=tips,
            x="day",
            y="tip",
            kind="swarm",
            height=4.5)
/Users/gm845/anaconda3/envs/msds597/lib/python3.12/site-packages/seaborn/categorical.py:3399: UserWarning:
9.7% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
/Users/gm845/anaconda3/envs/msds597/lib/python3.12/site-packages/seaborn/categorical.py:3399: UserWarning:
8.0% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
sns.catplot(data=tips,
            x="day",
            y="tip",
            hue="time",
            kind="swarm",
            height=4.5)
/Users/gm845/anaconda3/envs/msds597/lib/python3.12/site-packages/seaborn/categorical.py:3399: UserWarning:
9.7% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
/Users/gm845/anaconda3/envs/msds597/lib/python3.12/site-packages/seaborn/categorical.py:3399: UserWarning:
8.0% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
/Users/gm845/anaconda3/envs/msds597/lib/python3.12/site-packages/seaborn/categorical.py:3399: UserWarning:
6.5% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
We can use different color palettes:
https://seaborn.pydata.org/tutorial/color_palettes.html
sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="size",
            col='sex',
            palette=sns.color_palette("Blues"),
            height=4)

sns.catplot(data=tips,
            x="total_bill",
            y="day",
            hue="time",
            col='sex',
            height=4)
Categorical distribution plots
sns.catplot(data=tips,
            x="day",
            y="total_bill",
            kind="box",
            height=4)

sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="smoker",
            kind="box",
            height=4)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    kind='violin'
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    col='time',
    kind="violin",
    split=True,
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    kind="violin",
    inner='stick',
    split=True,
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    col="sex",
    kind="violin",
    inner='stick',
    split=True,
)
Categorical estimate plots
Barplots take the mean by default.
sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="sex",
            # errorbar="ci" - default - uses bootstrapping to compute a confidence interval around the estimate
            kind="bar",
            height=3)
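The bootstrapped confidence interval mentioned in the comment can be sketched as follows. This is a percentile bootstrap for the mean on made-up data; the function name, the 1000 resamples, and the 95% level are assumptions for illustration rather than a description of Seaborn's internals.

```python
import numpy as np

def bootstrap_ci_mean(values, n_boot=1000, level=95, seed=0):
    """Percentile bootstrap interval for the mean."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # resample with replacement and record the mean of each resample
    idx = rng.integers(0, values.size, size=(n_boot, values.size))
    boot_means = values[idx].mean(axis=1)
    alpha = (100 - level) / 2
    lo, hi = np.percentile(boot_means, [alpha, 100 - alpha])
    return lo, hi

rng = np.random.default_rng(1)
bills = rng.gamma(shape=4.0, scale=5.0, size=60)   # made-up, right-skewed "bill" amounts
lo, hi = bootstrap_ci_mean(bills)
print(f"mean = {bills.mean():.2f}, 95% bootstrap CI = ({lo:.2f}, {hi:.2f})")
```

The bar height corresponds to the sample mean and the error bar to the interval (lo, hi).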
(tips.groupby(['sex', 'day']))['total_bill'].agg(['mean', 'std'])
sex | day | mean | std |
---|---|---|---|
Male | Thur | 18.714667 | 8.019728 |
Male | Fri | 19.857000 | 10.015847 |
Male | Sat | 20.802542 | 9.836306 |
Male | Sun | 21.887241 | 9.129142 |
Female | Thur | 16.715312 | 7.759764 |
Female | Fri | 14.145556 | 4.788547 |
Female | Sat | 19.680357 | 8.806470 |
Female | Sun | 19.872222 | 7.837513 |
sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="sex",
            kind="bar",
            errorbar='sd',
            height=3) # interval is +/- 1 sd around the estimate
Here are some more details about the kinds of error bars available in Seaborn.
sns.catplot(
    data=tips,
    x="day",
    hue="sex",
    kind="count", # no calculating mean, just count
    height=4
)

sns.catplot(data=tips,
            x="day",
            y="tip",
            hue="sex",
            kind="point",
            markers=['<', 'o'],
            height=3,
            aspect=1.5)
Let’s see if we can do something more informative.
tips['perc'] = tips['tip'] / tips['total_bill']
tips_subset = tips[['sex', 'day', 'perc']]
tips_mean = tips_subset.groupby('day')['perc'].mean()
tips_mean = pd.DataFrame(tips_mean)
tips_mean.columns = ['mean']
tips_subset = tips_subset.merge(tips_mean, on='day')
tips_subset['diff'] = tips_subset['perc'] - tips_subset['mean']
tips_subset
 | sex | day | perc | mean | diff |
---|---|---|---|---|---|
0 | Female | Sun | 0.059447 | 0.166897 | -0.107451 |
1 | Male | Sun | 0.160542 | 0.166897 | -0.006356 |
2 | Male | Sun | 0.166587 | 0.166897 | -0.000310 |
3 | Male | Sun | 0.139780 | 0.166897 | -0.027117 |
4 | Female | Sun | 0.146808 | 0.166897 | -0.020090 |
... | ... | ... | ... | ... | ... |
239 | Male | Sat | 0.203927 | 0.153152 | 0.050775 |
240 | Female | Sat | 0.073584 | 0.153152 | -0.079568 |
241 | Male | Sat | 0.088222 | 0.153152 | -0.064929 |
242 | Male | Sat | 0.098204 | 0.153152 | -0.054947 |
243 | Female | Thur | 0.159744 | 0.161276 | -0.001531 |
244 rows × 5 columns
sns.catplot(tips_subset, x='day', y='diff', hue='sex', kind='box')
plt.axhline(y=0, color='darkgrey', linestyle='dotted', alpha=0.7)
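The per-day mean and the difference from it can also be computed without a merge. The following sketch uses groupby(...).transform("mean") on a small made-up table that only borrows the column names of tips_subset; it is an alternative idiom, not the approach used above.

```python
import pandas as pd

# small made-up table with the same column names as tips_subset
df = pd.DataFrame({
    "day": ["Thur", "Thur", "Fri", "Fri", "Fri"],
    "sex": ["Male", "Female", "Male", "Female", "Male"],
    "perc": [0.10, 0.20, 0.15, 0.25, 0.20],
})

# transform("mean") returns the group mean aligned to the original rows,
# so no separate merge step is needed
df["mean"] = df.groupby("day")["perc"].transform("mean")
df["diff"] = df["perc"] - df["mean"]
print(df)
```

Because the result stays aligned with the original index, the row order of the table is unchanged.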
Another categorical example
crashes = sns.load_dataset("car_crashes")
crashes = crashes.sort_values("total", ascending=False)
crashes.head()
 | total | speeding | alcohol | not_distracted | no_previous | ins_premium | ins_losses | abbrev |
---|---|---|---|---|---|---|---|---|
40 | 23.9 | 9.082 | 9.799 | 22.944 | 19.359 | 858.97 | 116.29 | SC |
34 | 23.9 | 5.497 | 10.038 | 23.661 | 20.554 | 688.75 | 109.72 | ND |
48 | 23.8 | 8.092 | 6.664 | 23.086 | 20.706 | 992.61 | 152.56 | WV |
3 | 22.4 | 4.032 | 5.824 | 21.056 | 21.280 | 827.34 | 142.39 | AR |
17 | 21.4 | 4.066 | 4.922 | 16.692 | 16.264 | 872.51 | 137.13 | KY |
sns.stripplot(crashes, x='total',
              y='abbrev')
plt.gca().xaxis.grid(False)
plt.gca().yaxis.grid(True)
plt.gcf().set_size_inches(5, 10)
plt.gca().set(xlim=(0, 25), title='Total crashes', xlabel="", ylabel="")
[(0.0, 25.0),
Text(0.5, 1.0, 'Total crashes'),
Text(0.5, 0, ''),
Text(0, 0.5, '')]
Heatmaps
rates = pd.read_csv('../data/rates.csv')
rates.Time = pd.to_datetime(rates.Time)
corr_mat = rates.corr(numeric_only=True)
corr_mat
 | USD | JPY | BGN | CZK | DKK | GBP | CHF |
---|---|---|---|---|---|---|---|
USD | 1.000000 | -0.103337 | NaN | -0.218888 | -0.232007 | 0.074199 | -0.042449 |
JPY | -0.103337 | 1.000000 | NaN | 0.655093 | 0.463510 | 0.484794 | 0.901636 |
BGN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
CZK | -0.218888 | 0.655093 | NaN | 1.000000 | 0.008358 | 0.128065 | 0.649767 |
DKK | -0.232007 | 0.463510 | NaN | 0.008358 | 1.000000 | 0.307508 | 0.604572 |
GBP | 0.074199 | 0.484794 | NaN | 0.128065 | 0.307508 | 1.000000 | 0.424830 |
CHF | -0.042449 | 0.901636 | NaN | 0.649767 | 0.604572 | 0.424830 | 1.000000 |
corr_mat = corr_mat.drop(index='BGN', columns='BGN')
sns.heatmap(corr_mat)
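An all-NaN row and column in a correlation matrix, like the BGN entries above, typically appears when a column has zero variance, since the correlation then divides by zero. Whether that is the cause in the rates data is an assumption; the sketch below reproduces the effect with made-up data and shows one way to drop such columns programmatically.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# made-up data: column "B" never changes
demo = pd.DataFrame({
    "A": rng.normal(size=50),
    "B": np.full(50, 2.0),
    "C": rng.normal(size=50),
})

corr = demo.corr()
print(corr)                                   # row and column "B" are all NaN

# drop columns with a single unique value before computing correlations
constant = demo.columns[demo.nunique() <= 1]
corr_clean = demo.drop(columns=constant).corr()
print(corr_clean)
```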
For continuous scales, we can use the argument cmap (before, we used palette for discrete colors).
sns.heatmap(corr_mat, cmap='RdBu', center = 0)

sns.heatmap(corr_mat, cmap='RdBu_r', center = 0, annot=True)
Clustermap
# Load the brain networks example dataset
df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0)
# Select a subset of the networks
used_networks = [1, 5, 6, 7, 8, 12, 13, 17]
used_columns = (df.columns.get_level_values("network")
                          .astype(int)
                          .isin(used_networks))
df = df.loc[:, used_columns]

sns.heatmap(df.corr(), cmap='vlag', center = 0)
We can plot clusters using sns.clustermap, which plots a matrix dataset as a hierarchically-clustered heatmap.
Hierarchical clustering:
1. treat every data point as its own cluster
2. form one cluster by combining the two closest clusters
3. repeat step 2 until there is one big cluster
sns.clustermap uses scipy for hierarchical clustering, details here.
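The three steps of hierarchical clustering can be sketched naively in Python. This illustration uses single linkage (the distance between two clusters is the distance between their closest members), which is only one of several linkage rules SciPy offers; the function name agglomerate and the toy points are assumptions.

```python
import numpy as np

def agglomerate(points, n_clusters=1):
    """Naive agglomerative clustering with single linkage (an illustrative choice)."""
    points = np.asarray(points, dtype=float)
    clusters = [[i] for i in range(len(points))]      # step 1: every point is its own cluster
    merges = []
    while len(clusters) > n_clusters:                 # step 3: repeat until few enough clusters
        best = None
        # find the two closest clusters
        for a in range(len(clusters)):
            for b in range(a + 1, len(clusters)):
                d = min(np.linalg.norm(points[i] - points[j])
                        for i in clusters[a] for j in clusters[b])
                if best is None or d < best[0]:
                    best = (d, a, b)
        d, a, b = best
        merges.append((clusters[a], clusters[b], d))  # record the merge, as a dendrogram would
        merged = clusters[a] + clusters[b]            # step 2: combine them into one cluster
        clusters = [c for k, c in enumerate(clusters) if k not in (a, b)] + [merged]
    return clusters, merges

pts = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.2], [5.0, 5.0], [5.1, 5.0]]
clusters, merges = agglomerate(pts, n_clusters=2)
print(clusters)
```

Running the loop down to a single cluster and recording each merge distance gives the information a dendrogram displays.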
# get color palette
network_pal = sns.husl_palette(len(used_networks), s=.45)
# return hues with constant lightness and saturation in the HUSL system.
# s: saturation intensity
network_lut = dict(zip(map(str, used_networks), network_pal))

df.columns.get_level_values("network")
Index(['1', '1', '5', '5', '6', '6', '6', '6', '7', '7', '7', '7', '7', '7',
'8', '8', '8', '8', '8', '8', '12', '12', '12', '12', '12', '13', '13',
'13', '13', '13', '13', '17', '17', '17', '17', '17', '17', '17'],
dtype='object', name='network')
# Convert the palette to vectors that will be drawn on the side of the matrix
networks = df.columns.get_level_values("network")
network_colors = pd.Series(networks, index=df.columns).map(network_lut)

g = sns.clustermap(df.corr(), center=0, cmap="vlag",
                   row_colors=network_colors, col_colors=network_colors,
                   row_cluster = False,
                   dendrogram_ratio=(.1, .2), # how much of plot should be dendrogram
                   cbar_pos=(.02, .32, .03, .2), # (left, bottom, width, height)
                   linewidths=.75, # space between squares
                   figsize=(12, 13))
Colors
Matplotlib master class
So far, we have used Seaborn functions, which are built on matplotlib.
Full customization of Seaborn plots requires some knowledge of matplotlib concepts.
We’ll start with some basic matplotlib functions.
data1_x = np.linspace(-4, 4, num=50)

data1_y1 = np.cos(data1_x)
data1_y2 = np.sin(data1_x)

plt.plot(data1_x, data1_y1)
plt.plot(data1_x, data1_y2)
# plt.show() # needed when matplotlib is used in a terminal or .py script
https://matplotlib.org/stable/api/markers_api.html
plt.plot(data1_x, data1_y1, marker="X")

plt.scatter(data1_x, data1_y1)
plt.plot(data1_x, data1_y1)
- Call plt.legend() to show labels
- Call plt.xlabel(my_label) to add a label to the x-axis
- Call plt.ylabel(my_label) to add a label to the y-axis
- Call plt.title(my_title) to add a title to the plot
plt.plot(data1_x, data1_y1, label='cosine')
plt.plot(data1_x, data1_y2, label='sine')
plt.legend()
Other useful things
- plt.legend(loc="lower left", ncol=2): loc is lower|center|upper + left|center|right
- plt.xlim(left, right)
- plt.ylim(bottom, top)
plt.plot(data1_x, data1_y1, label='cosine')
plt.plot(data1_x, data1_y2, label='sine')
plt.legend(loc="lower right", ncol=2)
Details
When we call plt.plot(), Matplotlib implicitly creates:
- a Figure
- Axes
A Figure is a top-level container that can hold one or more plots. It is like the blank canvas for your plots. It can contain: multiple subplots, colorbars, legends, titles.
An Axes is the plot where the data is displayed, containing the data points, tick marks, grid lines.
To have more control over our plots, we can specify the Figure and Axes directly.
# create a figure object: the container of the plot, and with an axis on it
figure, ax = plt.subplots()

ax.plot(data1_x, data1_y1)
ax.plot(data1_x, data1_y2)
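To see the implicitly created objects, we can ask pyplot for the current Figure and Axes after an implicit plt.plot() call and compare with the explicit interface. This is a small sketch; the variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")       # start from a clean state
x = np.linspace(-4, 4, num=50)

# implicit interface: pyplot creates a Figure and an Axes behind the scenes
plt.plot(x, np.cos(x))
implicit_fig = plt.gcf()   # "get current figure"
implicit_ax = plt.gca()    # "get current axes"

# explicit interface: we create and hold the objects ourselves
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

print(type(implicit_fig).__name__, type(implicit_ax).__name__)
print(implicit_ax in implicit_fig.axes, ax in fig.axes)
```

Holding fig and ax explicitly makes it unambiguous which plot each later call modifies, which matters once a figure has several subplots.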
Anatomy of a matplotlib figure
Credits: https://matplotlib.org/stable/gallery/showcase/anatomy.html#anatomy-of-a-figure
Subplots
Credits: https://matplotlib.org/3.1.0/gallery/subplots_axes_and_figures/subplots_demo.html
figure, axes = plt.subplots(
    nrows=2, ncols=2,
    figsize=(7,4),
    gridspec_kw={'width_ratios': [1, 2]} # specifies width ratios
)

axes[0,0].plot(data1_x, data1_y1, label='cosine')
axes[1,0].plot(data1_x, data1_y2, label='sine')

axes[0,0].legend()
axes[1,0].legend()

# all the plt.xlabel become ax.set_xlabel ...
axes[0,0].set_xlabel("x")
axes[1,0].set_xlabel("x")

axes[0,0].set_ylabel("y")
axes[1,0].set_ylabel("$y_1$")

axes[0,0].set_title("Cosine Plot")
axes[1,0].set_title("Sine Plot")

axes[0,1].hist(np.random.normal(size=(500)))
axes[1,1].hist(np.random.beta(0.5, 0.5, size=(500)))

axes[0,1].set_ylabel('Density')

figure.suptitle("The suptitle") # suptitle stands for "super-title" - it describes ALL plots in a figure.

# This automatically adjusts padding to fit plots in figure without overlap
figure.tight_layout()

figure, axes = plt.subplots(
    nrows=1, ncols=2,
    figsize=(5.5,2)
)

axes[0].plot(data1_x, data1_y1, linestyle=(0, (3, 5, 1, 5)), color='green')
axes[1].plot(data1_x, data1_y2)

extra_ax = figure.add_axes((0.7, 0.7, 0.2,0.2)) # coordinates from bottom left, then size
extra_ax.plot(data1_x, data1_y1)
See more linestyles here.
Subplot Mosaic:
plt.subplot_mosaic takes a list of lists: each inner list represents a row, and each element is a key naming the Axes in that column. Repeating a key makes that Axes span several cells.
(Adapted from https://matplotlib.org/stable/users/explain/axes/arranging_axes.html#basic-2x2-grid)
fig, axd = plt.subplot_mosaic([['upper left', 'right'],
                               ['lower left', 'right']],
                              figsize=(5.5, 3.5), layout="constrained")
for k, ax in axd.items():
    ax.text(0.5, 0.5,
            f'axd[{k!r}]', # !r adds quotations
            ha="center", va="center",
            fontsize=14, color="darkgrey")

fig.suptitle('plt.subplot_mosaic()')
Text(0.5, 0.98, 'plt.subplot_mosaic()')
fig, axd = plt.subplot_mosaic([['upper left', 'right'],
                               ['lower left', 'right']],
                              figsize=(6, 4), layout="constrained")

rates.USD.plot(ax=axd['upper left'])
rates.USD.hist(ax=axd['right'])
rates.reset_index().plot.scatter(x='index', y='USD', ax=axd['lower left'])
Adding multiple lines
Example from https://wesmckinney.com/book/plotting-and-visualization
fig, ax = plt.subplots(figsize=(8, 4))

ax.plot(np.random.randn(1000).cumsum(), color="black", label="one")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dashed",label="two")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dotted",label="three")
ax.legend()
Adding text
fig, ax = plt.subplots(figsize=(6, 3))

ax.plot(np.random.randn(1000).cumsum(), color="black", label="one")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dashed",label="two")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dotted",label="three")

ax.legend()
ax.text(500, 0, "Hello world!",
        family="monospace", fontsize=10)
Text(500, 0, 'Hello world!')
Back to mpg data:
fig, ax = plt.subplots(figsize=(6, 4))

ax.scatter(mpg.displ, mpg.hwy)
ax.set_xlabel('Engine Displacement (L)')
ax.set_ylabel('Highway miles per gallon')
Text(0, 0.5, 'Highway miles per gallon')
outliers = (mpg.hwy > 40) | ((mpg.displ > 5) & (mpg.hwy > 20))
inliers = ~outliers
inds = np.where(outliers)[0]
inds
array([ 23, 24, 25, 26, 27, 158, 212, 221, 222])
col = pd.Series(['#1f77b4', '#ff7f0e']).take(outliers)
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(mpg.displ, mpg.hwy, c=col)
ax.set_xlabel('Engine Displacement (L)')
ax.set_ylabel('Highway miles per gallon')

for i in inds:
    ax.annotate(text = mpg.model[i],
                xy = (mpg.displ[i], mpg.hwy[i]),
                xytext = (mpg.displ[i] + 0.2, mpg.hwy[i]))
Plotting datetime and adding annotations
from datetime import datetime
data = pd.read_csv("../data/spx.csv", index_col=0, parse_dates=True)
spx = data["SPX"]
spx.asof(datetime(2008, 1, 1))
np.float64(1468.36)
fig, ax = plt.subplots(figsize=(8,4))

spx.plot(ax=ax, color="black")

crisis_data = [
    (datetime(2007, 10, 11), "Peak of bull market"),
    (datetime(2008, 3, 12), "Bear Stearns Fails"),
    (datetime(2008, 9, 15), "Lehman Bankruptcy")
]

for date, label in crisis_data:
    ax.annotate(label, # Text to display
                xy=(date, spx.asof(date) + 75), # Point to annotate (arrow points here)
                xytext=(date, spx.asof(date) + 225), # Text position
                arrowprops=dict(facecolor="black", headwidth=4, width=2,headlength=4), # Arrow properties
                horizontalalignment="left", verticalalignment="top",
                )

# Zoom in on 2007-2010
ax.set_xlim(["1/1/2007", "1/1/2011"])
ax.set_ylim([600, 1800])

ax.set_title("Important dates in the 2008–2009 financial crisis")
Text(0.5, 1.0, 'Important dates in the 2008–2009 financial crisis')
Matplotlib configuration
Matplotlib comes configured with color schemes and defaults.
The default behavior can be customized via global parameters governing figure size, subplot spacing, colors, font sizes, grid styles, etc.
All of the current configuration settings are found in the plt.rcParams dictionary, and they can be restored to their default values by calling the plt.rcdefaults() function.
plt.rcdefaults()
plt.rcParams['figure.dpi']
100.0
plt.rc("figure", dpi=120)
# see what rcParams changes to
plt.rcParams['figure.dpi']
120.0
To modify configurations, there are a few options:
- use plt.rc to change plt.rcParams directly: useful if we want to affect all plots
- create a figure with a specific dpi/size: fig = plt.figure(dpi=120, figsize=[10, 12]): useful for a single plot
- use a context manager with plt.rc_context({"figure.dpi": 120}): useful for a series of plots
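The three options can be compared side by side. The sketch below records figure.dpi at each stage; the specific dpi values are arbitrary choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

plt.rcdefaults()
default_dpi = plt.rcParams["figure.dpi"]

# context manager: temporary change, undone when the block exits
with plt.rc_context({"figure.dpi": 120}):
    inside_dpi = plt.rcParams["figure.dpi"]
    fig_ctx, ax_ctx = plt.subplots()        # created with dpi 120

after_dpi = plt.rcParams["figure.dpi"]      # back to the default

# single figure with its own dpi; global settings are untouched
fig_one = plt.figure(dpi=150, figsize=[4, 3])

# global change that affects every later figure until reset
plt.rc("figure", dpi=90)
global_dpi = plt.rcParams["figure.dpi"]
plt.rcdefaults()                            # restore the defaults again

print(default_dpi, inside_dpi, after_dpi, fig_one.dpi, global_dpi)
```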
list(plt.rcParams.keys())[0:5]
['_internal.classic_mode',
'agg.path.chunksize',
'animation.bitrate',
'animation.codec',
'animation.convert_args']
import matplotlib.dates as mdates
import matplotlib
with plt.rc_context(
    {"figure.figsize": [15, 3],
     "figure.dpi": 110,
     "axes.labelsize": 15}), plt.style.context("dark_background"):
    plt.plot(rates['Time'], rates['USD'], color='red')
    plt.ylabel("USD exchange rate")
    plt.xlabel("date")

    ax = plt.gca() # this gets current axes

    my_fmt = mdates.DateFormatter("%d-%m-%Y")
    ax.xaxis.set_major_formatter(my_fmt)

    xlocator = mdates.DayLocator(interval=5) # major ticks every 5 days
    ax.xaxis.set_major_locator(xlocator)

    ax.minorticks_on()
    ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(5))

    # plt.xticks affects current active plot
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor") # anchor: aligns the unrotated text and then rotates the text around the point of alignment.

    # configure grid lines
    plt.grid(axis="x", alpha=0.5)
    plt.grid(axis="y", which="major", linewidth=0.5, linestyle="-", alpha=0.5)
Saving figures
From Data Science: A First Introduction with Python (Timbers et al. 2022):
Generally speaking, images come in two flavors: raster formats and vector formats.
Raster images:
- 2D grid of square pixels, each with its own color. Can be lossy or lossless.
- A format is lossy if the image cannot be perfectly re-created when loading. A format is lossless if it can re-create the original image exactly.
- Common file types:
- JPEG (.jpg, .jpeg): lossy, usually used for photographs
- PNG (.png): lossless, usually used for plots / line drawings
- BMP (.bmp): lossless, raw image data, no compression (rarely used)
- TIFF (.tif, .tiff): typically lossless, no compression, used mostly in graphic arts, publishing
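To make the lossy/lossless distinction concrete, the sketch below writes the same pixel grid to PNG and to JPEG in memory and reads it back. It assumes Pillow (PIL) is available, which is normally the case wherever matplotlib is installed; the random test image and the helper name round_trip are made up for the example.

```python
import io

import numpy as np
from PIL import Image  # Pillow; assumed installed alongside matplotlib

# A small random RGB image: 64 x 64 pixels, one byte per color channel
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)


def round_trip(array, fmt):
    """Save the array to an in-memory file in the given format, then reload it."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format=fmt)
    buffer.seek(0)
    return np.array(Image.open(buffer))


png_pixels = round_trip(pixels, "PNG")
jpeg_pixels = round_trip(pixels, "JPEG")

# Lossless: every pixel comes back exactly as it was saved
png_identical = np.array_equal(pixels, png_pixels)

# Lossy: some pixel values change; report the largest change on the 0-255 scale
jpeg_max_error = np.abs(pixels.astype(int) - jpeg_pixels.astype(int)).max()

print("PNG identical to original:", png_identical)
print("Largest JPEG pixel error:", jpeg_max_error)
```

Random noise is close to a worst case for JPEG; photographs typically show much smaller errors, which is why the format is acceptable for them but less suited to plots with sharp lines and text.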
Vector images:
- represented as a collection of mathematical objects (lines, surfaces, shapes, curves). When the computer displays the image, it redraws all of the elements using their mathematical formulas.
- Common file types:
- SVG (.svg): general-purpose use
- EPS (.eps): general-purpose use (rarely used)
Note: The portable document format PDF (.pdf) is commonly used to store both raster and vector formats. If you try to open a PDF and it’s taking a long time to load, it may be because there is a complicated vector graphics image that your computer is rendering.
Pros and cons
- Time to load: Raster images of fixed width/height take the same amount of space and time to load; a vector image takes space and time to load depending on how complex the image is.
- Pixelation: Raster images can look pixelated if you zoom in; vector images can be zoomed into without loss of image quality.
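The trade-off can be seen directly in matplotlib. In this minimal sketch (figure size, dpi values, and the helper name png_shape are arbitrary choices for the example), the pixel grid of a PNG is fixed by figure size times dpi, whereas the SVG output is a text description of shapes with no pixel grid at all.

```python
import io

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))  # size in inches
ax.plot([0, 1, 2], [0, 1, 4])


def png_shape(figure, dpi):
    """Save the figure as PNG in memory and return (rows, columns) in pixels."""
    buffer = io.BytesIO()
    figure.savefig(buffer, format="png", dpi=dpi)
    buffer.seek(0)
    return plt.imread(buffer, format="png").shape[:2]


# Raster: the number of pixels is fixed at save time by figsize * dpi
low_res = png_shape(fig, dpi=100)
high_res = png_shape(fig, dpi=200)

# Vector: the file stores drawing instructions as text, so zooming re-renders them
svg_buffer = io.BytesIO()
fig.savefig(svg_buffer, format="svg")
svg_text = svg_buffer.getvalue().decode("utf-8")

print("PNG pixels at dpi=100:", low_res)
print("PNG pixels at dpi=200:", high_res)
print("SVG starts with:", svg_text[:40].strip())
plt.close(fig)
```

Doubling the dpi quadruples the number of pixels in the raster file, while the size of the SVG depends only on how many elements the plot contains.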
Fig. 4.30 from Data Science: A First Introduction with Python
Things to think about
- adjust the size of your figure based on your output goal; you don’t want points to look too small on presentation slides, for example
- adjust your text size to be readable based on your output goal (e.g. presentations require larger text than papers)
with plt.rc_context({'font.size': 10}):
    sns.relplot(data=tips,
                x="total_bill",
                y="tip",
                hue="sex",
                aspect=1.5,
                height=10)

    # change font.size and height to see output

    plt.savefig('../fig/tips.png', dpi=300)
    plt.savefig('../fig/tips.pdf')
    plt.savefig('../fig/tips.svg')
FileNotFoundError: [Errno 2] No such file or directory: '../fig/tips.png'
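The FileNotFoundError above occurs because savefig does not create missing directories: the target folder has to exist before saving. One way to guard against this, sketched here with a temporary directory standing in for ../fig (the directory, file names, and plotted data are all hypothetical), is to create the folder with pathlib first.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

# Hypothetical output folder; in the lecture code this would be Path("../fig")
output_dir = Path(tempfile.mkdtemp()) / "fig"
output_dir.mkdir(parents=True, exist_ok=True)  # no error if it already exists

fig, ax = plt.subplots()
ax.scatter([10, 20, 30], [1.5, 3.0, 4.0])
ax.set_xlabel("total_bill")
ax.set_ylabel("tip")

# Save one raster copy and two vector copies of the same figure
saved_files = []
for extension in ("png", "pdf", "svg"):
    target = output_dir / f"example.{extension}"
    fig.savefig(target, dpi=300)  # dpi mainly matters for the raster PNG
    saved_files.append(target)

plt.close(fig)
print([path.name for path in saved_files])
```

With seaborn figure-level functions such as relplot, the returned FacetGrid also has a savefig method, which may be more convenient than plt.savefig when several figures are open.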
Resources
Here are some helpful examples:
Visualization more broadly:
A good textbook:
- Fundamentals of Data Visualization, Wilke