Lecture 5 - Visualization

Overview

We provide general principles for visualization and cover the basics of plotting in Python with Seaborn and Matplotlib.

References

This lecture contains material from:

Introduction

Asking questions about data

A visualization helps answer a question about a dataset.

A good visualization will clearly answer your question without distraction; a great visualization can illustrate what the question was without any text explanation.

Types of questions (Leek and Peng, 2015)

From Data Science: A First Introduction with Python:

  • Descriptive: a question that asks about summaries of a dataset without interpretation
  • Exploratory: a question that asks if there are patterns or trends in a single dataset (often used to develop hypotheses)
  • Predictive: a question that asks about predicting measurements for individuals
  • Inferential: a question that looks for patterns or trends and tests for applicability in a wider population
  • Causal: a question that asks whether changing one variable will lead to a change in another variable, on average in a population
  • Mechanistic: a question that asks about the underlying mechanism

Generally we use visualization to help answer descriptive and exploratory questions.

Examples of exploratory questions:

  1. What is the data quality? Are there inconsistencies or illogical values?
  2. What are the distributions of variables?
  3. What associations do there seem to be between variables?

From Gelman et al. (2002): visualizations are most helpful to:

  • facilitate comparisons
  • reveal trends

A famous graph

John Snow’s cholera map (1855):

  • Cholera is an intestinal disease that can cause death within hours of onset of vomiting

  • In 1854, there was a cholera epidemic in Soho, London

  • John Snow, a physician, believed that contaminated water wells were the source of the epidemic

  • This is the famous map he created, which showed cholera cases (black rectangles) clustered around the well at the intersection of Broad and Cambridge Streets.

Source: Wikipedia

Visualization principles

Convey the message

  • make sure the visualization answers the question as simply and plainly as possible
  • use legends and labels so the visualization is understandable without reading the surrounding text
  • ensure the data are clearly visible
  • use color schemes understandable by those with color blindness
  • choose scale appropriately (e.g. log scale)

Minimize noise

  • avoid chartjunk! e.g. distracting grid lines, gratuitous use of icons, unnecessary 3D
  • use colors sparingly - too many colors can be distracting
  • if your plot has too many dots or lines and starts to look messy, you need to do something different!
  • don’t adjust axes to zoom in on small differences. If the difference is small, show that it is small!

Figure 24.1: From Data Looks Better Naked by Darkhorse Analytics

No-nos

  • No pie charts!! (They encode quantitative information in angles and areas, which are hard to judge)
  • No stacked bar charts! (Only the bottom segments share a common baseline, so the other segments are hard to compare)

Examples of misleading graphs.

Choice of visualization

Summary of basic visualizations (from Data Science: A First Introduction with Python)

  • scatter plots visualize the relationship between two quantitative variables
  • line plots visualize trends with respect to an independent, ordered quantity (e.g., time)
  • bar plots visualize comparisons of amounts
  • histograms visualize the distribution of one quantitative variable (i.e., all its possible values and how often they occur)

Import packages

Let’s now import some packages.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

We will also use statsmodels (install this package in your msds597 environment).

import statsmodels

Seaborn Relplot

The sns.relplot() function visualizes statistical relationships using either scatter plots or line plots.

The default is a scatter plot. For line plots, use the argument kind='line'.

Let’s first look at the tips data, available in Seaborn.

tips = sns.load_dataset("tips")
tips.head()
   total_bill   tip     sex  smoker  day    time  size
0       16.99  1.01  Female      No  Sun  Dinner     2
1       10.34  1.66    Male      No  Sun  Dinner     3
2       21.01  3.50    Male      No  Sun  Dinner     3
3       23.68  3.31    Male      No  Sun  Dinner     2
4       24.59  3.61  Female      No  Sun  Dinner     4
g = sns.relplot(data=tips,
            x="total_bill",
            y="tip")
g.figure.set_size_inches(5, 3)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue="time",
            aspect=1.5,
            height=3.5)

sns.relplot(
    data=tips,
    x="total_bill",
    y="tip",
    hue="smoker",
    style="time",
    aspect=1.5,
    height=3.5
)

sns.relplot(
    data=tips,
    x="total_bill",
    y="tip",
    hue="size",
    aspect=1.5,
    height=3.5
)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time',
            size="size",
            aspect=1.5,
            height=3.5)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time',      # point color
            style='day',     # marker shape
            size="size",     # marker size
            col='sex',       # subplots in columns
            row='smoker',    # subplots in rows
            kind='scatter',  # 'scatter' (default) or 'line'
            height=2.5,
            aspect=1.5)

sns.relplot(data=tips,
            x="total_bill",
            y="tip",
            hue='time',
            style='sex',
            size="size",
            aspect=1.5,
            height=3.5)

dowjones = sns.load_dataset("dowjones")
dowjones.head()
         Date  Price
0  1914-12-01  55.00
1  1915-01-01  56.55
2  1915-02-01  56.00
3  1915-03-01  58.30
4  1915-04-01  66.45
sns.relplot(data=dowjones,
            x="Date",
            y="Price",
            kind="line",
            height=3,
            aspect=3)

fmri = sns.load_dataset("fmri")
fmri.head()
   subject  timepoint  event    region     signal
0      s13         18   stim  parietal  -0.017552
1       s5         14   stim  parietal  -0.080883
2      s12         18   stim  parietal  -0.081033
3      s11         18   stim  parietal  -0.046134
4      s10         18   stim  parietal  -0.037970
sns.relplot(data=fmri,
            x="timepoint",
            y="signal",
            kind="line",
            aspect=2,
            height=3)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    hue="event",
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    hue="region",
    style="event",
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri, kind="line",
    x="timepoint", y="signal", hue="region", style="event",
    dashes=False, markers=True,
    height=3,
    aspect=2
)

sns.relplot(
    data=fmri,
    kind="line",
    x="timepoint",
    y="signal",
    row="region",
    col="event",
    height=3,
    aspect=1.5
)

sns.relplot(
    data=fmri, kind="line",
    x="timepoint", y="signal", hue="subject", 
    col="region", row="event", height=3)

Linear regression

Seaborn's lmplot function plots the data together with a fitted linear regression line and, by default, a 95% confidence band. Let’s explore this function using the mpg dataset.

This dataset contains a subset of the fuel economy data that the EPA makes available on http://fueleconomy.gov. It contains only models which had a new release every year between 1999 and 2008 - this was used as a proxy for the popularity of the car.

Format: a data frame with 234 rows and 11 variables:

  1. manufacturer: manufacturer name
  2. model: model name
  3. displ: engine displacement, in litres
  4. year: year of manufacture
  5. cyl: number of cylinders
  6. trans: type of transmission
  7. drv: the type of drive train, where f = front-wheel drive, r = rear-wheel drive, 4 = 4wd
  8. cty: city miles per gallon
  9. hwy: highway miles per gallon
  10. fl: fuel type
  11. class: “type” of car
mpg = pd.read_csv('../data/mpg.csv')
mpg.head()
   manufacturer  model  displ  year  cyl       trans  drv  cty  hwy  fl    class
0          audi     a4    1.8  1999    4    auto(l5)    f   18   29   p  compact
1          audi     a4    1.8  1999    4  manual(m5)    f   21   29   p  compact
2          audi     a4    2.0  2008    4  manual(m6)    f   20   31   p  compact
3          audi     a4    2.0  2008    4    auto(av)    f   21   30   p  compact
4          audi     a4    2.8  1999    6    auto(l5)    f   16   26   p  compact

Some questions we may ask:

  • Do cars with big engines use more fuel than cars with small engines?
g = sns.relplot(mpg, x='displ', y='hwy', hue='class')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')

g = sns.lmplot(mpg, x='displ', y='hwy')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')

g = sns.lmplot(mpg, x='displ', y='hwy', hue='drv')
g.set_axis_labels('Engine Displacement (L)','Highway miles per gallon')

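We can also fit a second-order (quadratic) polynomial regression, here with sns.regplot (the Axes-level function that lmplot builds on) and order=2. This estimates the coefficients \beta_0, \beta_1, \beta_2 of the model: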

y_i = \beta_0 + \beta_1 x_i + \beta_2 x_i^2 + \varepsilon_i
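As a rough illustration of what estimating these coefficients means, the sketch below fits the quadratic model by ordinary least squares with NumPy. The data are made up to stand in for displ and hwy (mpg.csv is not loaded here), so the printed numbers are illustrative only. We believe regplot's order=2 fit is computed with np.polyfit, which solves the same least-squares problem, but treat that as an assumption.

```python
import numpy as np

# Hypothetical synthetic data standing in for (displ, hwy): a curved trend plus noise
rng = np.random.default_rng(0)
x = rng.uniform(1.5, 7.0, size=200)
y = 45 - 9 * x + 0.6 * x**2 + rng.normal(scale=1.5, size=x.size)

# Design matrix with columns 1, x, x^2, matching beta_0, beta_1, beta_2
X = np.column_stack([np.ones_like(x), x, x**2])

# Ordinary least squares: minimize the sum of squared residuals ||y - X beta||^2
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
print("beta_0, beta_1, beta_2 =", np.round(beta, 2))

# np.polyfit solves the same problem; it orders coefficients from highest degree down
assert np.allclose(np.polyfit(x, y, deg=2)[::-1], beta)
```

The quadratic term lets the fitted curve bend, which is why it can follow the flattening of highway mileage at large displacements better than a straight line.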

sns.regplot(mpg, x='displ', y='hwy', order=2)
plt.gca().set_xlabel('Engine Displacement (L)')
plt.gca().set_ylabel('Highway miles per gallon')

We can also use locally weighted scatterplot smoothing (LOWESS) by passing lowess=True; this option requires statsmodels to be installed.

From statsmodels:

Suppose the input data has N points. The algorithm works by estimating y_i by taking the frac*N closest points to (x_i,y_i) based on their x values and estimating y_i using a weighted linear regression. The weight for (x_j,y_j) is tricube function applied to abs(x_i-x_j).
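To make the quoted description concrete, here is a minimal single-pass NumPy sketch of the estimate at one point x0: take the ceil(frac*N) nearest neighbours in x, weight them with the tricube function of their scaled distance, and fit a weighted straight line. This is our own simplification, not statsmodels' code: statsmodels' lowess also runs robustifying iterations (it=3 by default), and its default frac is 2/3. The helper names tricube and lowess_point and the data are hypothetical.

```python
import numpy as np

def tricube(d):
    """Tricube weight (1 - |d|^3)^3 for |d| < 1, and 0 otherwise."""
    d = np.abs(d)
    return np.where(d < 1, (1 - d**3) ** 3, 0.0)

def lowess_point(x, y, x0, frac=2/3):
    """Single-pass LOWESS estimate at x0 (no robustness iterations)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    k = int(np.ceil(frac * len(x)))       # number of neighbours used
    dist = np.abs(x - x0)
    idx = np.argsort(dist)[:k]            # the k points closest to x0 along the x-axis
    h = dist[idx].max()                   # radius of the local window
    w = tricube(dist[idx] / h)            # weight 1 at x0, falling to 0 at the window edge
    # Weighted least squares for y ~ a + b*x: scale rows by sqrt(weight), then solve OLS
    X = np.column_stack([np.ones(k), x[idx]])
    sw = np.sqrt(w)
    (a, b), *_ = np.linalg.lstsq(X * sw[:, None], y[idx] * sw, rcond=None)
    return a + b * x0

# Hypothetical noisy curve, smoothed by evaluating the local fit at every x
rng = np.random.default_rng(1)
x = np.sort(rng.uniform(0, 10, 100))
y = np.sin(x) + rng.normal(scale=0.3, size=x.size)
smooth = np.array([lowess_point(x, y, xi, frac=0.3) for xi in x])
```

A larger frac averages over more neighbours and gives a smoother curve, much like a larger bandwidth for the kernel density estimates later in this lecture.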

sns.regplot(mpg, x='displ', y='hwy', lowess=True)
plt.gca().set_xlabel('Engine Displacement (L)')
plt.gca().set_ylabel('Highway miles per gallon')

Seaborn Displot

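sns.displot is Seaborn's figure-level function for visualizing the distribution of data: histograms (kind='hist', the default), kernel density estimates (kind='kde'), and empirical CDFs (kind='ecdf').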

penguins = sns.load_dataset("penguins")
penguins.head()
   species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex
0   Adelie  Torgersen            39.1           18.7              181.0       3750.0    Male
1   Adelie  Torgersen            39.5           17.4              186.0       3800.0  Female
2   Adelie  Torgersen            40.3           18.0              195.0       3250.0  Female
3   Adelie  Torgersen             NaN            NaN                NaN          NaN     NaN
4   Adelie  Torgersen            36.7           19.3              193.0       3450.0  Female

Histograms

One of the most common approaches to visualizing a distribution is a histogram.

In a histogram, the x-axis (corresponding to the data variable) is divided into a set of bins and the count of observations falling within each bin is shown using the height of the corresponding bar.
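The binning step can be sketched with np.histogram, which is a reasonable mental model for what the bins and binwidth arguments below control (we do not claim seaborn computes its bin edges exactly this way). The bimodal sample is made up to resemble flipper lengths.

```python
import numpy as np

# Hypothetical bimodal sample standing in for flipper lengths (mm)
rng = np.random.default_rng(2)
x = np.concatenate([rng.normal(190, 6, 150), rng.normal(217, 6, 120)])

# Fixed number of bins: numpy splits [min, max] into equal-width intervals
counts, edges = np.histogram(x, bins=20)
print(len(counts), np.diff(edges)[0])   # 20 bins, each of width (max - min) / 20

# Fixed bin width: build the edges ourselves, starting at the minimum
binwidth = 7.1
edges_w = np.arange(x.min(), x.max() + binwidth, binwidth)
counts_w, _ = np.histogram(x, bins=edges_w)

# Every observation lands in exactly one bin, whichever way the bins are chosen
assert counts.sum() == counts_w.sum() == x.size
```

Fixing the number of bins lets the width follow from the data range; fixing the bin width does the reverse.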

sns.displot(penguins,
            x="flipper_length_mm",
            height=3,
            aspect=1.5)

We can choose the bin width:

sns.displot(penguins,
            x="flipper_length_mm",
            binwidth=7.1,
            height=3,
            aspect=1.5)

Or specify the number of bins we want:

sns.displot(penguins,
            x="flipper_length_mm",
            bins=20,
            height=3,
            aspect=1.5
            )

sns.displot(penguins, x="flipper_length_mm",
            binwidth=0.3,
            height=3,
            aspect=1.5) # binwidth too small: many narrow bins are empty, so the histogram looks spiky

sns.displot(penguins,
            x="flipper_length_mm",
            binwidth=30,
            height=3,
            aspect=1.5) # binwidth too big: the two peaks in the data are not visible

sns.displot(penguins,
            x="flipper_length_mm",
            bins=15,
            height=3,
            aspect=1.5)

tips = sns.load_dataset("tips")
tips.head()
   total_bill   tip     sex  smoker  day    time  size
0       16.99  1.01  Female      No  Sun  Dinner     2
1       10.34  1.66    Male      No  Sun  Dinner     3
2       21.01  3.50    Male      No  Sun  Dinner     3
3       23.68  3.31    Male      No  Sun  Dinner     2
4       24.59  3.61  Female      No  Sun  Dinner     4
sns.displot(tips,
            x="size",
            discrete=True,
            height=3,
            aspect=1.5)

sns.displot(tips,
            x="day",
            height=3,
            aspect=1.5)
# no need to specify discrete=True because seaborn figures this out on its own for a categorical variable like day

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            col='island',
            height=4)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            multiple="dodge",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            col="sex",
            height=3,
            aspect=1.5)

Kernel density estimation

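Rather than using discrete bins, a kernel density estimate (KDE) places a smooth kernel K (a non-negative function that integrates to 1) on each observation and averages them:

\widehat{f}_h(x) = \frac{1}{nh}\sum_{i=1}^n K\left(\frac{x - x_i}{h}\right),

where h > 0 is the bandwidth, which controls how smooth the estimate is.

The default is a Gaussian kernel whose bandwidth is scaled by the spread of the data, so the effective bandwidth is h\sigma: \widehat{f}_h(x) = \frac{1}{nh\sigma}\frac{1}{\sqrt{2\pi}}\sum_{i=1}^n \exp\left(-\frac{(x - x_i)^2}{2h^2\sigma^2}\right)

where \sigma is the standard deviation of the sample \{x_1,\dots, x_n\} and h is the factor set by bw_method (Scott's rule, h = n^{-1/5}, by default).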
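The Gaussian formula above translates almost line by line into NumPy. In this sketch h is the bandwidth factor and h*sigma the effective bandwidth; we believe this matches scipy.stats.gaussian_kde, where a scalar bw_method is used directly as that factor, but treat the correspondence with seaborn's bw_method as an assumption. The helper gaussian_kde_1d and the sample are hypothetical.

```python
import numpy as np

def gaussian_kde_1d(data, grid, h=None):
    """Evaluate the Gaussian KDE formula on `grid`.

    h is the bandwidth factor; the effective bandwidth is h * sigma.
    If h is None, use Scott's rule h = n ** (-1/5).
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    sigma = data.std(ddof=1)              # sample standard deviation
    if h is None:
        h = n ** (-1 / 5)                 # Scott's rule in one dimension
    bw = h * sigma                        # effective bandwidth h*sigma
    # (len(grid), n) matrix of standardized distances (x - x_i) / (h*sigma)
    u = (grid[:, None] - data[None, :]) / bw
    return np.exp(-0.5 * u**2).sum(axis=1) / (n * bw * np.sqrt(2 * np.pi))

rng = np.random.default_rng(3)
sample = np.concatenate([rng.normal(190, 6, 150), rng.normal(217, 6, 120)])
grid = np.linspace(150, 260, 1101)
dx = grid[1] - grid[0]
for h in (0.05, None, 2.0):               # small factor, Scott's rule, large factor
    f = gaussian_kde_1d(sample, grid, h)
    print(h, round(f.sum() * dx, 3))      # Riemann sum of the density over the grid
```

If you plot f against grid for the three factors, you should see the same qualitative pattern as the bw_method plots below: wiggly for 0.05, two clear peaks for Scott's rule, and a single over-smoothed hump for 2 (whose tails also spill past this grid).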

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=0.05,
            height=3,
            aspect=1.5) # setting the bandwidth
# overfitting
# curve is jittery and the jitter is from noise, bandwidth is too small

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=0.3,
            height=3,
            aspect=1.5) # setting the bandwidth

sns.displot(penguins,
            x="flipper_length_mm",
            kind="kde",
            bw_method=2,
            height=3,
            aspect=1.5) # setting the bandwidth
# underfitting:
# bandwidth too big, curve too smoothed out, not informative

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            col='island',
            kind="kde",
            height=3,
            aspect=1.5)

sns.displot(penguins,
            x="flipper_length_mm",
            hue="species",
            kind="kde",
            fill=True,
            height=3,
            aspect=1.5)

Bivariate distributions

# bivariate histogram
sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            height=4)

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            cbar=True,
            height=3,
            aspect=1.25) # adding a colorbar

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            hue='species',
            kind='hist',
            height=3) # default is hist

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            kind='kde',
            hue='species',
            height=3) 

sns.displot(penguins,
            x="bill_length_mm",
            y="bill_depth_mm",
            hue="species",
            col='island',
            kind="kde",
            height=3)

Plotting joint and marginal distributions

sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              hue='species',
              height=4)

sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              kind='hist',
              height=4)

sns.jointplot(data=penguins,
              x="bill_length_mm",
              y="bill_depth_mm",
              hue='species',
              kind='kde',
              height=4)

g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)
type(g)
seaborn.axisgrid.JointGrid

jointplot() is an interface to the JointGrid class, which has helpful methods such as plot_joint() and plot_marginals() for adding layers to the central and marginal Axes.

g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)

g.plot_joint(sns.kdeplot,
             color="red")

# scatter plot in blue
g = sns.jointplot(data=penguins,
                  x="bill_length_mm",
                  y="bill_depth_mm",
                  height=4)

# kde plot in red, same plot
g.plot_joint(sns.kdeplot,
             color="red")

# rug plot in green
g.plot_marginals(sns.rugplot,
                 color="green", height=0.15)

Multiple variables

sns.pairplot(penguins, hue='species',
             height=2.5)

Seaborn Catplot

catplot is similar to relplot, but designed specifically for categorical data.

Generally speaking, there are three different families of categorical plots in Seaborn:

Categorical scatterplots:

  • stripplot() (with kind="strip"; the default)
  • swarmplot() (with kind="swarm")

Categorical distribution plots:

  • boxplot() (with kind="box")
  • violinplot() (with kind="violin")

Categorical estimate plots:

  • pointplot() (with kind="point")
  • barplot() (with kind="bar")
  • countplot() (with kind="count")

Categorical scatterplots

sns.catplot(data=tips,
            x="day",
            y="tip",
            # kind='strip',  # default is 'strip'
            jitter=False,    # default is True
            height=4.5)

sns.catplot(data=tips,
            x="day",
            y="tip",
            kind="swarm",
            height=4.5)
UserWarning: 9.7% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
UserWarning: 8.0% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.

sns.catplot(data=tips,
            x="day",
            y="tip",
            hue="time",
            kind="swarm",
            height=4.5)
UserWarning: 9.7% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
UserWarning: 8.0% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.
UserWarning: 6.5% of the points cannot be placed; you may want to decrease the size of the markers or use stripplot.

We can use different color palettes:

https://seaborn.pydata.org/tutorial/color_palettes.html

sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="size",
            col='sex',
            palette=sns.color_palette("Blues"),
            height=4)

sns.catplot(data=tips,
            x="total_bill",
            y="day",
            hue="time",
            col='sex',
            height=4)

Categorical distribution plots

sns.catplot(data=tips,
            x="day",
            y="total_bill",
            kind="box",
            height=4)

sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="smoker",
            kind="box",
            height=4)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    kind='violin'
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    col='time',
    kind="violin",
    split=True,
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    hue="sex",
    kind="violin",
    inner='stick',
    split=True,
)

sns.catplot(
    data=tips,
    x="day",
    y="total_bill",
    col="sex",
    kind="violin",
    inner='stick',
    split=True,
)

Categorical estimate plots

Bar plots show the mean of each group by default (controlled by the estimator argument).

sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="sex",
            # errorbar="ci" - default - uses bootstrapping to compute a confidence interval around the estimate
            kind="bar",
            height=3)
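The bootstrap behind errorbar="ci" can be sketched in a few lines: resample the observations with replacement many times, recompute the mean each time, and report the middle 95% of those means. Seaborn's default is a 95% interval from 1000 bootstrap resamples (n_boot=1000), but its internal implementation differs in details; the helper bootstrap_ci and the simulated bill values are hypothetical.

```python
import numpy as np

def bootstrap_ci(values, n_boot=1000, level=95, seed=0):
    """Percentile bootstrap confidence interval for the mean."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    # Each row of idx indexes one resample of the same size, drawn with replacement
    idx = rng.integers(0, values.size, size=(n_boot, values.size))
    boot_means = values[idx].mean(axis=1)
    tail = (100 - level) / 2
    return np.percentile(boot_means, [tail, 100 - tail])

rng = np.random.default_rng(5)
bills = rng.gamma(shape=4.0, scale=5.0, size=60)   # hypothetical skewed "total_bill" values
low, high = bootstrap_ci(bills)
print(f"mean={bills.mean():.2f}, 95% CI=({low:.2f}, {high:.2f})")
```

Because the interval comes from resampling, it does not assume the data are normally distributed, and rerunning with a different seed shifts its ends slightly.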

tips.groupby(['sex', 'day'], observed=True)['total_bill'].agg(['mean', 'std'])
                  mean        std
sex    day                       
Male   Thur  18.714667   8.019728
       Fri   19.857000  10.015847
       Sat   20.802542   9.836306
       Sun   21.887241   9.129142
Female Thur  16.715312   7.759764
       Fri   14.145556   4.788547
       Sat   19.680357   8.806470
       Sun   19.872222   7.837513
sns.catplot(data=tips,
            x="day",
            y="total_bill",
            hue="sex",
            kind="bar",
            errorbar='sd',
            height=3)  # interval is +/- 1 sd around the estimate

Here are some more details about the kinds of error bars available in Seaborn.
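As a rough guide, errorbar accepts "ci", "pi", "se" and "sd". The sketch below (hypothetical helper error_interval, NumPy only) computes the interval that "sd", "se" and a 95% "pi" would describe for one group. We assume the sample standard deviation (ddof=1), which is what the std column in the pandas table above uses. "sd" and "pi" describe the spread of the data, while "se" and "ci" describe the uncertainty of the estimated mean.

```python
import numpy as np

def error_interval(values, method):
    """Interval implied by a few errorbar styles for one group of values."""
    v = np.asarray(values, dtype=float)
    m = v.mean()
    if method == "sd":                     # mean +/- 1 sample standard deviation
        half = v.std(ddof=1)
        return m - half, m + half
    if method == "se":                     # mean +/- 1 standard error of the mean
        half = v.std(ddof=1) / np.sqrt(v.size)
        return m - half, m + half
    if method == "pi":                     # central 95% percentile interval of the data
        lo, hi = np.percentile(v, [2.5, 97.5])
        return lo, hi
    raise ValueError(f"unknown method: {method}")

vals = [1, 2, 3, 4, 5]
for method in ("sd", "se", "pi"):
    print(method, error_interval(vals, method))
```

The "se" interval shrinks as the group gets larger, while the "sd" interval does not, which is why the two can look very different for the same data.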

sns.catplot(
    data=tips,
    x="day",
    hue="sex",
    kind="count", # no y variable: bar heights are the number of observations
    height=4
)

sns.catplot(data=tips,
            x="day",
            y="tip",
            hue="sex",
            kind="point",
            markers=['<', 'o'],
            height=3,
            aspect=1.5)

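Let’s see if we can do something more informative: express each tip as a fraction of the total bill, subtract that day's mean fraction, and compare the differences between sexes.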

tips['perc'] = tips['tip'] / tips['total_bill']
tips_subset = tips[['sex', 'day', 'perc']]
tips_mean = tips_subset.groupby('day', observed=True)['perc'].mean()  # mean tip fraction per day
tips_mean = pd.DataFrame(tips_mean)
tips_mean.columns = ['mean']
tips_subset = tips_subset.merge(tips_mean, on='day')  # attach each row's day mean
tips_subset['diff'] = tips_subset['perc'] - tips_subset['mean']
tips_subset
        sex   day      perc      mean       diff
0    Female   Sun  0.059447  0.166897  -0.107451
1      Male   Sun  0.160542  0.166897  -0.006356
2      Male   Sun  0.166587  0.166897  -0.000310
3      Male   Sun  0.139780  0.166897  -0.027117
4    Female   Sun  0.146808  0.166897  -0.020090
...     ...   ...       ...       ...        ...
239    Male   Sat  0.203927  0.153152   0.050775
240  Female   Sat  0.073584  0.153152  -0.079568
241    Male   Sat  0.088222  0.153152  -0.064929
242    Male   Sat  0.098204  0.153152  -0.054947
243  Female  Thur  0.159744  0.161276  -0.001531

244 rows × 5 columns

sns.catplot(tips_subset, x='day', y='diff', hue='sex', kind='box')
plt.axhline(y=0, color='darkgrey', linestyle='dotted', alpha=0.7)
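The merge step above can also be written with groupby(...).transform("mean"), which returns each row's group mean aligned to the original index, so no intermediate DataFrame or merge is needed. A minimal sketch on a small hypothetical frame with the same columns as tips_subset:

```python
import pandas as pd

# Small hypothetical frame with the same columns as tips_subset
df = pd.DataFrame({
    "sex":  ["Female", "Male", "Male", "Female", "Male"],
    "day":  ["Sun", "Sun", "Sat", "Sat", "Sat"],
    "perc": [0.06, 0.16, 0.20, 0.07, 0.09],
})

# transform("mean") yields one value per row (the mean of that row's day),
# aligned with df's index, so it can be assigned as a new column directly
df["mean"] = df.groupby("day", observed=True)["perc"].transform("mean")
df["diff"] = df["perc"] - df["mean"]
print(df)
```

Within each day the differences sum to zero, which is why the boxes in the plot are centred near the dotted line at 0.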

Another categorical example

crashes = sns.load_dataset("car_crashes")
crashes = crashes.sort_values("total", ascending=False)
crashes.head()
    total  speeding  alcohol  not_distracted  no_previous  ins_premium  ins_losses  abbrev
40   23.9     9.082    9.799          22.944       19.359       858.97      116.29      SC
34   23.9     5.497   10.038          23.661       20.554       688.75      109.72      ND
48   23.8     8.092    6.664          23.086       20.706       992.61      152.56      WV
3    22.4     4.032    5.824          21.056       21.280       827.34      142.39      AR
17   21.4     4.066    4.922          16.692       16.264       872.51      137.13      KY
sns.stripplot(crashes, x='total', 
              y='abbrev')
plt.gca().xaxis.grid(False)
plt.gca().yaxis.grid(True)
plt.gcf().set_size_inches(5, 10)
plt.gca().set(xlim=(0, 25), title='Total crashes', xlabel="", ylabel="")

Heatmaps

rates = pd.read_csv('../data/rates.csv')
rates.Time = pd.to_datetime(rates.Time)
corr_mat = rates.corr(numeric_only=True)
corr_mat
           USD        JPY  BGN        CZK        DKK        GBP        CHF
USD   1.000000  -0.103337  NaN  -0.218888  -0.232007   0.074199  -0.042449
JPY  -0.103337   1.000000  NaN   0.655093   0.463510   0.484794   0.901636
BGN        NaN        NaN  NaN        NaN        NaN        NaN        NaN
CZK  -0.218888   0.655093  NaN   1.000000   0.008358   0.128065   0.649767
DKK  -0.232007   0.463510  NaN   0.008358   1.000000   0.307508   0.604572
GBP   0.074199   0.484794  NaN   0.128065   0.307508   1.000000   0.424830
CHF  -0.042449   0.901636  NaN   0.649767   0.604572   0.424830   1.000000
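Every BGN entry, even the diagonal, is NaN. A NaN diagonal points to a column with zero variance. One likely explanation, which we have not verified against rates.csv: the Bulgarian lev is pegged to the euro, so if these are euro reference rates the BGN column never changes. The sketch below reproduces the effect on a hypothetical frame.

```python
import pandas as pd

# Hypothetical exchange-rate-like frame in which 'PEG' never changes
df = pd.DataFrame({
    "A":   [1.10, 1.12, 1.08, 1.15],
    "B":   [140.0, 142.5, 139.0, 145.0],
    "PEG": [1.9558] * 4,
})
corr = df.corr(numeric_only=True)
print(corr)
# Pearson correlation divides by each column's standard deviation;
# a constant column has standard deviation 0, so every entry involving it is NaN
```

A NaN row would show up as a blank band in the heatmap, so dropping it first keeps the plot readable.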
corr_mat = corr_mat.drop(index='BGN', columns='BGN')  # BGN is NaN everywhere, even on the diagonal
sns.heatmap(corr_mat)

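For continuous color scales we use the argument cmap (earlier we used palette for discrete colors). With a diverging colormap such as 'RdBu', center=0 places the neutral middle color at zero correlation.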

sns.heatmap(corr_mat, cmap='RdBu', center = 0)

sns.heatmap(corr_mat, cmap='RdBu_r', center = 0, annot=True)

Clustermap

# Load the brain networks example dataset
df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0)
# Select a subset of the networks
used_networks = [1, 5, 6, 7, 8, 12, 13, 17]
used_columns = (df.columns.get_level_values("network")
                          .astype(int)
                          .isin(used_networks))
df = df.loc[:, used_columns]
sns.heatmap(df.corr(), cmap='vlag', center = 0)

We can plot clusters using sns.clustermap, which plots a matrix dataset as a hierarchically-clustered heatmap.

Hierarchical clustering:

  1. treat every data point as its own cluster
  2. form one cluster by combining the two closest clusters
  3. repeat step 2 until there is one big cluster

sns.clustermap uses scipy for hierarchical clustering, details here.
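To make the three steps concrete, here is a deliberately naive NumPy sketch (hypothetical helper agglomerate, roughly O(n^3), fine for a handful of points). It uses average linkage, where the distance between two clusters is the mean pairwise distance between their members. We believe this matches clustermap's defaults (method="average", metric="euclidean"), but scipy's linkage, which clustermap actually calls, is far more efficient.

```python
import numpy as np

def agglomerate(points):
    """Naive average-linkage agglomerative clustering (illustration only).

    Returns the merges in order as (cluster_a, cluster_b, distance).
    """
    pts = np.asarray(points, dtype=float)
    # Pairwise Euclidean distances between the points
    d = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
    clusters = [[i] for i in range(len(pts))]  # step 1: every point is its own cluster
    merges = []
    while len(clusters) > 1:                   # step 3: repeat until one cluster remains
        best = None
        for a in range(len(clusters)):
            for b in range(a + 1, len(clusters)):
                # average linkage: mean distance between members of the two clusters
                dist = d[np.ix_(clusters[a], clusters[b])].mean()
                if best is None or dist < best[2]:
                    best = (a, b, dist)
        a, b, dist = best                      # step 2: merge the two closest clusters
        merges.append((clusters[a], clusters[b], dist))
        clusters = [c for k, c in enumerate(clusters) if k not in (a, b)] + [clusters[a] + clusters[b]]
    return merges

pts = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0], [20.0, 0.0]]
for left, right, dist in agglomerate(pts):
    print(left, "+", right, f"at distance {dist:.2f}")
```

The sequence of merge distances is what the dendrogram draws: nearby points join at low heights, and the isolated point joins last.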

# get color palette
network_pal = sns.husl_palette(len(used_networks), s=.45)
# return hues with constant lightness and saturation in the HUSL system.
# s: saturation intensity
network_lut = dict(zip(map(str, used_networks), network_pal))
df.columns.get_level_values("network")
Index(['1', '1', '5', '5', '6', '6', '6', '6', '7', '7', '7', '7', '7', '7',
       '8', '8', '8', '8', '8', '8', '12', '12', '12', '12', '12', '13', '13',
       '13', '13', '13', '13', '17', '17', '17', '17', '17', '17', '17'],
      dtype='object', name='network')
# Convert the palette to vectors that will be drawn on the side of the matrix
networks = df.columns.get_level_values("network")
network_colors = pd.Series(networks, index=df.columns).map(network_lut)
g = sns.clustermap(df.corr(), center=0, cmap="vlag",
                   row_colors=network_colors, col_colors=network_colors,
                   row_cluster = False,
                   dendrogram_ratio=(.1, .2), # how much of plot should be dendrogram
                   cbar_pos=(.02, .32, .03, .2), # (left, bottom, width, height)
                   linewidths=.75, # space between squares
                   figsize=(12, 13))

Colors

Matplotlib master class

So far, we have used Seaborn functions, which are built on matplotlib.

Full customization of Seaborn plots requires some knowledge of matplotlib concepts.

We’ll start with some basic matplotlib functions.

data1_x = np.linspace(-4, 4, num=50)

data1_y1 = np.cos(data1_x)
data1_y2 = np.sin(data1_x)
plt.plot(data1_x, data1_y1)
plt.plot(data1_x, data1_y2)
# plt.show()  # needed to display the figure in a .py script or plain terminal session; notebooks display figures automatically

Marker styles: https://matplotlib.org/stable/api/markers_api.html

plt.plot(data1_x, data1_y1, marker="X")

plt.scatter(data1_x, data1_y1)
plt.plot(data1_x, data1_y1)

  • Call plt.legend() to show labels
  • Call plt.xlabel(my_label) to add a label to the x-axis
  • Call plt.ylabel(my_label) to add a label to the y-axis
  • Call plt.title(my_title) to add a title to the plot
plt.plot(data1_x, data1_y1, label='cosine')
plt.plot(data1_x, data1_y2, label='sine')
plt.legend()

Other useful things

  • plt.legend(loc="lower left", ncol=2): lower|center|upper + left|center|right
  • plt.xlim(left, right)
  • plt.ylim(bottom, top)
plt.plot(data1_x, data1_y1, label='cosine')
plt.plot(data1_x, data1_y2, label='sine')
plt.legend(loc="lower right", ncol=2)

Details

When we call plt.plot(), Matplotlib implicitly creates (if they do not already exist):

  • a Figure
  • an Axes

A Figure is a top-level container that can hold one or more plots. It is like the blank canvas for your plots. It can contain: multiple subplots, colorbars, legends, titles.

An Axes is the plot where the data is displayed, containing the data points, tick marks, grid lines.

To have more control over our plots, we can specify the Figure and Axes directly.

# create a Figure (the container of the plot) with a single Axes on it
figure, ax = plt.subplots()
ax.plot(data1_x, data1_y1)
ax.plot(data1_x, data1_y2)
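A small sketch of how these objects relate: the Axes keeps a reference to its Figure, the Figure lists its Axes, and pyplot tracks a "current" Figure and Axes, which is what the implicit plt.plot() style draws on. The Agg backend is used only so the snippet runs without a display.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")                  # off-screen backend so this runs without a display
import matplotlib.pyplot as plt

x = np.linspace(-4, 4, num=50)

fig, ax = plt.subplots()               # one Figure holding one Axes
line, = ax.plot(x, np.cos(x))          # Axes.plot returns a list of Line2D objects

# The objects are linked: the Axes knows its Figure and the Figure lists its Axes
assert ax.figure is fig
assert fig.axes == [ax]

# pyplot's "current" Figure and Axes are the ones just created; plt.plot(),
# plt.xlabel(), etc. act on these implicitly
assert plt.gcf() is fig and plt.gca() is ax
plt.close(fig)
```

Holding on to fig and ax explicitly avoids surprises when several figures are open, since you no longer depend on which one pyplot considers current.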

Anatomy of a matplotlib figure

Credits: https://matplotlib.org/stable/gallery/showcase/anatomy.html#anatomy-of-a-figure

Subplots

Credits: https://matplotlib.org/3.1.0/gallery/subplots_axes_and_figures/subplots_demo.html

figure, axes = plt.subplots(
    nrows=2, ncols=2, 
    figsize=(7,4),
    gridspec_kw={'width_ratios': [1, 2]} # specifies width ratios
    )

axes[0,0].plot(data1_x, data1_y1, label='cosine')
axes[1,0].plot(data1_x, data1_y2, label='sine')

axes[0,0].legend()
axes[1,0].legend()

# all the plt.xlabel become ax.set_xlabel ...
axes[0,0].set_xlabel("x")
axes[1,0].set_xlabel("x")

axes[0,0].set_ylabel("y")
axes[1,0].set_ylabel("$y_1$")

axes[0,0].set_title("Cosine Plot")
axes[1,0].set_title("Sine Plot")

axes[0,1].hist(np.random.normal(size=(500)))
axes[1,1].hist(np.random.beta(0.5, 0.5, size=(500)))

axes[0,1].set_ylabel('Count')  # hist() shows counts unless density=True

figure.suptitle("The suptitle")  # suptitle stands for "super-title" - it describes ALL plots in a figure.

# This automatically adjusts padding to fit plots in figure without overlap
figure.tight_layout()

figure, axes = plt.subplots(
    nrows=1, ncols=2,
    figsize=(5.5,2) 
)

axes[0].plot(data1_x, data1_y1, linestyle=(0, (3, 5, 1, 5)), color='green')  # (offset, (on, off, on, off)) dash pattern: dash-dot
axes[1].plot(data1_x, data1_y2)

extra_ax = figure.add_axes((0.7, 0.7, 0.2, 0.2))  # (left, bottom, width, height) as fractions of the figure
extra_ax.plot(data1_x, data1_y1)

See more linestyles here.

Subplot Mosaic:

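plt.subplot_mosaic takes a list of lists: each inner list is one row of the layout, and each element is a label naming the Axes that occupies that cell. Repeating a label (like 'right' below) makes that Axes span several cells. It returns the Figure and a dictionary mapping each label to its Axes.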

(Adapted from https://matplotlib.org/stable/users/explain/axes/arranging_axes.html#basic-2x2-grid)

fig, axd = plt.subplot_mosaic([['upper left', 'right'],
                               ['lower left', 'right']],
                              figsize=(5.5, 3.5), layout="constrained")
for k, ax in axd.items():
    ax.text(0.5, 0.5, 
            f'axd[{k!r}]',   # !r formats k with repr(), which adds quotes
            ha="center", va="center", 
            fontsize=14, color="darkgrey")
    
fig.suptitle('plt.subplot_mosaic()')
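Two more features of the list-of-lists layout, shown on a hypothetical 3-row mosaic: repeating a label makes that Axes span several cells, and "." (the default empty_sentinel) leaves a cell blank. The labels "A" to "D" are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")                            # off-screen backend
import matplotlib.pyplot as plt

# One inner list per row; "B" spans two rows and "." leaves the bottom-left cell empty
fig, axd = plt.subplot_mosaic([["A", "B"],
                               ["C", "B"],
                               [".", "D"]],
                              figsize=(5, 4), layout="constrained")
print(sorted(axd))                               # labels of the Axes that were created
print(axd["B"].get_subplotspec().rowspan)        # rows of the grid that B covers
plt.close(fig)
```

Because the result is a dictionary, code can refer to panels by meaningful names instead of by positions such as axes[1, 0].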

fig, axd = plt.subplot_mosaic([['upper left', 'right'],
                               ['lower left', 'right']],
                              figsize=(6, 4), layout="constrained")

rates.USD.plot(ax=axd['upper left'])
rates.USD.hist(ax=axd['right'])
rates.reset_index().plot.scatter(x='index', y='USD', ax=axd['lower left'])

Adding multiple lines

Example from https://wesmckinney.com/book/plotting-and-visualization

fig, ax = plt.subplots(figsize=(8, 4))

ax.plot(np.random.randn(1000).cumsum(), color="black", label="one")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dashed",label="two")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dotted",label="three")
ax.legend()

Adding text

fig, ax = plt.subplots(figsize=(6, 3))

ax.plot(np.random.randn(1000).cumsum(), color="black", label="one")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dashed",label="two")
ax.plot(np.random.randn(1000).cumsum(), color="black", linestyle="dotted",label="three")
ax.legend()

ax.text(500, 0, "Hello world!",
        family="monospace", fontsize=10)

Back to mpg data:

fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(mpg.displ, mpg.hwy)
ax.set_xlabel('Engine Displacement (L)')
ax.set_ylabel('Highway miles per gallon')

outliers = (mpg.hwy > 40) | ((mpg.displ > 5) & (mpg.hwy > 20))
inliers = ~outliers
inds = np.where(outliers)[0]
inds
array([ 23,  24,  25,  26,  27, 158, 212, 221, 222])
# outliers in orange, everything else in the default blue
col = np.where(outliers, '#ff7f0e', '#1f77b4')
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(mpg.displ, mpg.hwy, c=col)
ax.set_xlabel('Engine Displacement (L)')
ax.set_ylabel('Highway miles per gallon')

# label each outlier with its model name, shifted 0.2 L to the right of the point
for i in inds:
    ax.annotate(text=mpg.model[i],
                xy=(mpg.displ[i], mpg.hwy[i]),
                xytext=(mpg.displ[i] + 0.2, mpg.hwy[i]))

Plotting datetime and adding annotations

from datetime import datetime

data = pd.read_csv("../data/spx.csv", index_col=0, parse_dates=True)
spx = data["SPX"]
spx.asof(datetime(2008, 1, 1))
np.float64(1468.36)
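`Series.asof` returns the last non-missing value at or before the requested date. This is useful for dates such as 1 January 2008, when no trading took place. A small sketch with made-up prices:

```python
import pandas as pd

# Made-up prices on three trading days (Jan 5-6, 2008 is a weekend)
s = pd.Series([100.0, 102.0, 101.0],
              index=pd.to_datetime(["2008-01-02", "2008-01-03", "2008-01-07"]))

exact = s.asof(pd.Timestamp("2008-01-03"))   # date present in the index
gap = s.asof(pd.Timestamp("2008-01-05"))     # falls back to Jan 3
early = s.asof(pd.Timestamp("2008-01-01"))   # before the first observation: NaN
```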
fig, ax = plt.subplots(figsize=(8,4))

spx.plot(ax=ax, color="black")

crisis_data = [
    (datetime(2007, 10, 11), "Peak of bull market"),
    (datetime(2008, 3, 12), "Bear Stearns Fails"),
    (datetime(2008, 9, 15), "Lehman Bankruptcy")
]

for date, label in crisis_data:
    ax.annotate(label,  # Text to display       
                xy=(date, spx.asof(date) + 75), # Point to annotate (arrow points here)
                xytext=(date, spx.asof(date) + 225), # Text position
                arrowprops=dict(facecolor="black", headwidth=4, width=2, headlength=4),  # Arrow properties
                horizontalalignment="left", verticalalignment="top",
                )

# Zoom in on 2007-2010
ax.set_xlim(["1/1/2007", "1/1/2011"])
ax.set_ylim([600, 1800])

ax.set_title("Important dates in the 2008–2009 financial crisis")

Matplotlib configuration

Matplotlib comes configured with color schemes and defaults.

The default behavior can be customized via global parameters governing figure size, subplot spacing, colors, font sizes, grid styles, etc.

All of the current configuration settings are found in the plt.rcParams dictionary, and they can be restored to their default values by calling the plt.rcdefaults() function.

plt.rcdefaults()
plt.rcParams['figure.dpi']
100.0
plt.rc("figure", dpi=120)
# see what rcParams changes to
plt.rcParams['figure.dpi']
120.0

To modify configurations, there are a few options:

  1. use plt.rc to change plt.rcParams globally: useful when a setting should apply to all subsequent plots
  2. pass settings when creating a figure, e.g. fig = plt.figure(dpi=120, figsize=[10, 12]): useful for a single plot
  3. use the context manager plt.rc_context({"figure.dpi": 120}) in a with statement: useful for a series of plots, since the previous settings are restored when the block exits

list(plt.rcParams.keys())[0:5]
['_internal.classic_mode',
 'agg.path.chunksize',
 'animation.bitrate',
 'animation.codec',
 'animation.convert_args']
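The following sketch checks option 3. It shows that `plt.rc_context` overrides settings only inside the `with` block and restores the previous values on exit. The dpi and font size are arbitrary:

```python
import matplotlib.pyplot as plt

before = plt.rcParams["figure.dpi"]

with plt.rc_context({"figure.dpi": 150, "font.size": 14}):
    inside = plt.rcParams["figure.dpi"]
    fig = plt.figure()   # created while the overrides are active
    fig_dpi = fig.dpi
    plt.close(fig)

after = plt.rcParams["figure.dpi"]  # back to the value before the block
```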
import matplotlib.dates as mdates
import matplotlib

with plt.rc_context(
        {"figure.figsize": [15, 3], 
         "figure.dpi": 110, 
         "axes.labelsize": 15}), plt.style.context("dark_background"):
    
    plt.plot(rates['Time'], rates['USD'], color='red')
    plt.ylabel("USD exchange rate")
    plt.xlabel("date")

    ax = plt.gca() # this gets current axes
    
    my_fmt = mdates.DateFormatter("%d-%m-%Y")
    ax.xaxis.set_major_formatter(my_fmt)
    xlocator = mdates.DayLocator(interval=5) # major ticks every 5 days
    ax.xaxis.set_major_locator(xlocator)
    ax.minorticks_on()
    # split each interval between major ticks into 5 parts (4 minor ticks)
    ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(5))

    # plt.xticks affects the current active Axes
    # rotation_mode="anchor": align the unrotated text first, then rotate it around the alignment point
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")

    # configure grid lines
    plt.grid(axis="x", alpha=0.5)
    plt.grid(axis="y", which="major", linewidth=0.5, linestyle="-", alpha=0.5)
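Choosing a DayLocator interval and a format string by hand works well for a fixed date range. An alternative, not the approach used in the cell above, is matplotlib's `AutoDateLocator`, which picks tick positions automatically, paired with `ConciseDateFormatter`, which shortens the labels. A sketch with a synthetic daily series standing in for `rates`:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Hypothetical daily series standing in for rates['Time'], rates['USD']
dates = pd.date_range("2024-01-01", periods=90, freq="D")
values = 1.1 + np.random.default_rng(0).normal(0, 0.003, 90).cumsum()

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(dates, values, color="red")

# Let matplotlib choose tick positions, then label them compactly:
# each tick shows only the part of the date that changes
locator = mdates.AutoDateLocator()
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
```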

Saving figures

From Data Science: A First Introduction with Python (Timbers et al. 2022):

Generally speaking, images come in two flavors: raster formats and vector formats.

Raster images:

  • 2D grid of square pixels, each with its own color. Can be lossy or lossless.
  • A format is lossy if the image cannot be perfectly re-created when loading. A format is lossless if it can re-create the original image exactly.
  • Common file types:
    • JPEG (.jpg, .jpeg): lossy, usually used for photographs
    • PNG (.png): lossless, usually used for plots / line drawings
    • BMP (.bmp): lossless, raw image data, no compression (rarely used)
    • TIFF (.tif, .tiff): typically lossless, often stored uncompressed, used mostly in graphic arts, publishing

Vector images:

  • represented as a collection of mathematical objects (lines, surfaces, shapes, curves). When the computer displays the image, it redraws all of the elements using their mathematical formulas.
  • Common file types:
    • SVG (.svg): general-purpose use
    • EPS (.eps): general-purpose use (rarely used)

Note: The portable document format PDF (.pdf) is commonly used to store both raster and vector formats. If you try to open a PDF and it’s taking a long time to load, it may be because there is a complicated vector graphics image that your computer is rendering.

Pros and cons

  • Time to load

A raster image of fixed width/height takes roughly the same amount of space and time to load regardless of what it shows (compression may shrink some images more than others). A vector image takes space and time to load depending on how complex the image is.

  • Pixelation

Raster images can look pixelated if you zoom in; vector images can be zoomed into without loss of image quality.
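The raster/vector distinction is visible in the files matplotlib writes. The sketch below saves one figure to memory as a PNG at two resolutions and as an SVG. The PNG's pixel dimensions are the figure size in inches times the dpi, while the SVG contains drawing instructions rather than pixels. Reading the size directly from the PNG header is for illustration only; in practice an image library would do this.

```python
import io
import struct
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))  # 4 x 3 inches
ax.plot([0, 1, 2], [0, 1, 0])

def png_size(dpi):
    """Render the figure to an in-memory PNG and return (width, height) in pixels."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    data = buf.getvalue()
    # PNG layout: 8-byte signature, 4-byte chunk length, b"IHDR",
    # then big-endian 4-byte width and height (bytes 16-24)
    return struct.unpack(">II", data[16:24])

svg = io.BytesIO()
fig.savefig(svg, format="svg")  # vector: shapes stored as paths, no fixed pixel grid
```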

Fig. 4.30 from Data Science: A First Introduction with Python

Things to think about

  • adjust the size of your figure based on your output goal; you don’t want points to look too small on presentation slides, for example
  • adjust your text size to be readable based on your output goal (e.g. presentations require larger text than papers)
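from pathlib import Path

with plt.rc_context({'font.size': 10}):
    sns.relplot(data=tips,
                x="total_bill",
                y="tip",
                hue="sex",
                aspect=1.5,
                height=10)

    # change font.size and height to see output

    # savefig does not create missing folders, so make ../fig first
    Path('../fig').mkdir(parents=True, exist_ok=True)
    plt.savefig('../fig/tips.png', dpi=300)  # raster: resolution set by dpi
    plt.savefig('../fig/tips.pdf')           # vector
    plt.savefig('../fig/tips.svg')           # vector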
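On the point about text size: seaborn provides preset plotting contexts ("paper", "notebook", "talk", "poster") that scale fonts and line widths together, so a figure can be adjusted for slides or for print in one step. A sketch in which the plotted data are placeholders:

```python
import seaborn as sns
import matplotlib.pyplot as plt

# Compare base font sizes across seaborn's preset contexts
for ctx in ["paper", "notebook", "talk", "poster"]:
    print(ctx, sns.plotting_context(ctx)["font.size"])

# Used as a context manager, the scaling applies only inside the block
with sns.plotting_context("talk"):
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot([0, 1, 2], [1, 3, 2])
    ax.set_xlabel("x")
    talk_label_size = ax.xaxis.label.get_size()
```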

Resources

Here are some helpful examples:

Visualization more broadly:

A good textbook:

  • Fundamentals of Data Visualization, Wilke