The GroupBy abstraction lets us explore relationships within a dataset.
A pivot table is a similar operation that is commonly seen in spreadsheets and other programs that operate on tabular data.
The pivot table takes simple column-wise data as input, and groups the entries into a two-dimensional table that provides a multidimensional summarization of the data.
The difference between pivot tables and GroupBy can sometimes cause confusion; it helps me to think of pivot tables as essentially a multidimensional version of GroupBy aggregation.
That is, you split-apply-combine, but both the split and the combine happen across a two-dimensional grid rather than a one-dimensional index.

The examples in this section use the Titanic passenger dataset, available through the Seaborn library:

import numpy as np
import pandas as pd
import seaborn as sns
titanic = sns.load_dataset('titanic')
titanic.head()
One way to begin exploring this data is with a GroupBy operation; for example, let's look at survival rate by gender:

titanic.groupby('sex')[['survived']].mean()
To look at survival by both sex and class using the vocabulary of GroupBy, we might proceed using something like this:
we group by class and gender, select survival, apply a mean aggregate, combine the resulting groups, and then unstack the hierarchical index to reveal the hidden multidimensionality. In code:

titanic.groupby(['sex', 'class'])['survived'].aggregate('mean').unstack()
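The individual steps of that pipeline are easier to see on a small table. The following is a minimal sketch on hypothetical toy data (the `toy` frame and its values are invented for illustration and are not the Titanic data); it shows the two-level index produced by the grouping and how unstacking turns it into a grid.

```python
import pandas as pd

# Hypothetical toy data standing in for the Titanic columns used above
toy = pd.DataFrame({
    'sex':      ['female', 'female', 'male', 'male', 'male', 'female'],
    'class':    ['First', 'Third', 'First', 'Third', 'Third', 'First'],
    'survived': [1, 0, 1, 0, 1, 1],
})

# Split on two keys, select one column, and apply a mean aggregate.
# The result is a Series indexed by a two-level MultiIndex (sex, class).
grouped = toy.groupby(['sex', 'class'])['survived'].mean()

# Unstacking moves the innermost index level ('class') into the columns,
# turning the long Series into a two-dimensional table.
table = grouped.unstack()

print(grouped)
print(table)
```

Printing both objects side by side makes it clear that the unstacked table holds the same numbers as the grouped Series, only arranged with one key along each axis.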
This two-dimensional GroupBy is common enough that Pandas includes a convenience routine, pivot_table, which succinctly handles this type of multidimensional aggregation.

Here is the equivalent operation using the pivot_table method of DataFrames:

titanic.pivot_table('survived', index='sex', columns='class')
This is more readable than the groupby approach, and produces the same result.
As you might expect of an early 20th-century transatlantic cruise, the survival gradient favors both women and higher classes.
First-class women survived with near certainty (hi, Rose!), while only one in ten third-class men survived (sorry, Jack!).

Just as in GroupBy, the grouping in pivot tables can be specified with multiple levels, and via a number of options.
For example, we might be interested in looking at age as a third dimension.
We'll bin the age using the pd.cut function:

age = pd.cut(titanic['age'], [0, 18, 80])
titanic.pivot_table('survived', ['sex', age], 'class')
The same strategy applies to the columns; for example, we can add the fare paid, using pd.qcut to automatically compute quantiles:

fare = pd.qcut(titanic['fare'], 2)
titanic.pivot_table('survived', ['sex', age], [fare, 'class'])
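The two binning functions used above behave differently, which is easy to miss. As a rough sketch on hypothetical values (the `values` Series below is invented for illustration), pd.cut places data into bins whose edges we supply, while pd.qcut chooses the edges from the data so that each bin holds a similar number of entries.

```python
import pandas as pd

# Hypothetical values, chosen only to make the bin edges easy to follow
values = pd.Series([1, 2, 3, 4, 50, 100])

# pd.cut uses the edges we supply; intervals are closed on the right
by_edges = pd.cut(values, [0, 10, 100])

# pd.qcut picks edges from the data so each bin holds a similar count
by_quantile = pd.qcut(values, 2)

print(by_edges.value_counts().sort_index())
print(by_quantile.value_counts().sort_index())
```

With fixed edges the bins can be very unevenly populated, whereas the quantile-based bins split the six values evenly. Either result can be passed to pivot_table as a grouping key, as in the examples above.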
The full call signature of the pivot_table method of DataFrames is as follows:

# call signature as of Pandas 0.18
DataFrame.pivot_table(data, values=None, index=None, columns=None,
                      aggfunc='mean', fill_value=None, margins=False,
                      dropna=True, margins_name='All')
Two of the options, fill_value and dropna, have to do with missing data and are fairly straightforward; we will not show examples of them here.

The aggfunc keyword controls what type of aggregation is applied, which is a mean by default.
As in GroupBy, the aggregation specification can be a string representing one of several common choices (e.g., 'sum', 'mean', 'count', 'min', 'max', etc.) or a function that implements an aggregation (e.g., np.sum(), min(), sum(), etc.).
Additionally, it can be specified as a dictionary mapping a column to any of the above desired options:

titanic.pivot_table(index='sex', columns='class',
                    aggfunc={'survived': 'sum', 'fare': 'mean'})
Notice that here we've omitted the values keyword; when specifying a mapping for aggfunc, this is determined automatically.

Totals along each grouping can be computed via the margins keyword:

titanic.pivot_table('survived', index='sex', columns='class', margins=True)
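To see what the aggfunc mapping and the margins option actually compute, it can help to try them on a table small enough to check by hand. The sketch below uses hypothetical toy data (the `toy` frame is invented for illustration); note that it contains no first-class male row, so that cell is left empty in the result.

```python
import pandas as pd

# Hypothetical toy data; note there is no ('male', 'First') row
toy = pd.DataFrame({
    'sex':      ['female', 'female', 'female', 'male', 'male'],
    'class':    ['First', 'First', 'Third', 'Third', 'Third'],
    'survived': [1, 1, 0, 0, 1],
    'fare':     [80.0, 60.0, 10.0, 8.0, 12.0],
})

# A different aggregate per column; values is inferred from the dict keys,
# and the resulting columns are a two-level index of (value, class)
mixed = toy.pivot_table(index='sex', columns='class',
                        aggfunc={'survived': 'sum', 'fare': 'mean'})

# margins=True appends an 'All' row and column, each computed from the
# underlying rows rather than by averaging the cells of the table
with_totals = toy.pivot_table('survived', index='sex', columns='class',
                              aggfunc='mean', margins=True)

print(mixed)
print(with_totals)
```

In this toy table the bottom-right margin is the overall mean of the survived column, which is the same kind of summary the margins row and column provide for the Titanic data.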
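The margin label can be specified with the margins_name keyword, which defaults to "All".

As a further example, consider the births data in the following CSV file:

# shell command to download the data: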
# !curl -O https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv
births = pd.read_csv('data/births.csv')
births.head()
births['decade'] = 10 * (births['year'] // 10)
births.pivot_table('births', index='decade', columns='gender', aggfunc='sum')
%matplotlib inline
import matplotlib.pyplot as plt
sns.set() # use Seaborn styles
births.pivot_table('births', index='year', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births per year');
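With a simple pivot table and the plot() method, we can immediately see the annual trend in births by gender. By eye, it appears that over the past 50 years male births have outnumbered female births by around 5%.

Before going further, we compute a robust estimate of the center and spread of the births values from the quartiles; the factor 0.74 converts the interquartile range into an approximate standard deviation for a Gaussian distribution:

quartiles = np.percentile(births['births'], [25, 50, 75])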
mu = quartiles[1]
sig = 0.74 * (quartiles[2] - quartiles[0])
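The factor of 0.74 used above comes from the fact that the interquartile range of a Gaussian distribution is about 1.349 standard deviations, and 1/1.349 is roughly 0.74. As a quick sanity check, here is a sketch on synthetic data (the sample, its parameters, and the injected outliers are assumptions made purely for illustration):

```python
import numpy as np

# Synthetic Gaussian sample with a known spread (an assumption for illustration)
rng = np.random.default_rng(0)
sample = rng.normal(loc=5000, scale=300, size=100_000)

q25, q50, q75 = np.percentile(sample, [25, 50, 75])
robust_sigma = 0.74 * (q75 - q25)  # IQR of a Gaussian is about 1.349 sigma

# A handful of extreme outliers barely moves the quartile-based estimate,
# while the ordinary standard deviation is inflated considerably
contaminated = np.concatenate([sample, np.full(100, 200_000.0)])
c25, c50, c75 = np.percentile(contaminated, [25, 50, 75])
robust_sigma_contaminated = 0.74 * (c75 - c25)

print(robust_sigma, sample.std())
print(robust_sigma_contaminated, contaminated.std())
```

This insensitivity to a few wild values is why a quartile-based estimate is a reasonable choice when the goal is to identify those values in the first place.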
With this we can use the query() method (discussed further in High-Performance Pandas: eval() and query()) to filter out rows with births outside these values:

births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')
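For readers less familiar with query(), the following sketch shows the same kind of filter next to its boolean-mask equivalent. The frame `df` and the values of `mu` and `sig` here are hypothetical stand-ins chosen for illustration, not the estimates computed from the births data.

```python
import pandas as pd

# Hypothetical daily counts with one obviously bad value
df = pd.DataFrame({'births': [4800, 5000, 5200, 4900, 5100, 90000]})
mu, sig = 5000, 150  # stand-ins for the robust estimates computed above

# '@name' inside the query string refers to a local Python variable
clipped = df.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')

# The same filter written as an explicit boolean mask
mask = (df['births'] > mu - 5 * sig) & (df['births'] < mu + 5 * sig)
same_result = clipped.equals(df[mask])

print(clipped)
print(same_result)
```

Both forms select the same rows; the query string is mainly a matter of readability when the condition involves several comparisons.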
Next we set the day column to integers; previously it had been a string because some entries in the dataset contained the value 'null':

# set 'day' column to integer; it originally was a string due to nulls
births['day'] = births['day'].astype(int)
# create a datetime index from the year, month, day
births.index = pd.to_datetime(10000 * births.year +
                              100 * births.month +
                              births.day, format='%Y%m%d')
births['dayofweek'] = births.index.dayofweek
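The integer encoding used above packs each date into a number such as 19690101, which pd.to_datetime can parse with the '%Y%m%d' format. A minimal sketch on hypothetical rows (the `df` frame and its values are invented for illustration) shows the encoding and, as an alternative, passing the year, month, and day columns to pd.to_datetime directly:

```python
import pandas as pd

# Hypothetical rows in the same year/month/day layout as the births table
df = pd.DataFrame({'year': [1969, 1969, 1988],
                   'month': [1, 12, 2],
                   'day': [1, 31, 29],
                   'births': [4046, 3900, 4500]})

# Encode each date as an integer such as 19690101, then parse it
encoded = 10000 * df.year + 100 * df.month + df.day
idx = pd.to_datetime(encoded, format='%Y%m%d')

# pd.to_datetime also accepts a frame with year, month, and day columns
idx_alt = pd.to_datetime(df[['year', 'month', 'day']])

df.index = pd.DatetimeIndex(idx)
df['dayofweek'] = df.index.dayofweek  # Monday is 0, Sunday is 6

print(df)
```

Both routes give the same dates. The second form avoids the arithmetic, at the cost of requiring the columns to carry exactly those names.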
import matplotlib.pyplot as plt
import matplotlib as mpl
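births.pivot_table('births', index='dayofweek',
                   columns='decade', aggfunc='mean').plot()
# fix the tick positions before labeling them, so each label lands on its day
plt.gca().set_xticks(range(7))
plt.gca().set_xticklabels(['Mon', 'Tues', 'Wed', 'Thurs', 'Fri', 'Sat', 'Sun'])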
plt.ylabel('mean births by day');
births_by_date = births.pivot_table('births',
                                    [births.index.month, births.index.day])
births_by_date.head()
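# pd.datetime has been removed from recent Pandas versions; use the standard library
from datetime import datetime
births_by_date.index = [datetime(2012, month, day)
                        for (month, day) in births_by_date.index]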
births_by_date.head()
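The dummy year 2012 in the code above is a leap year, which matters because the grouped index can include February 29. The following sketch, using a hypothetical list of (month, day) pairs, shows that the construction succeeds with 2012 but would fail with a non-leap year:

```python
from datetime import datetime

# (month, day) pairs as they would appear in the grouped index
pairs = [(1, 1), (2, 28), (2, 29), (12, 31)]

# 2012 is a leap year, so every pair, including February 29, is a valid date
dates = [datetime(2012, month, day) for (month, day) in pairs]

# With a non-leap dummy year the same construction fails on February 29
try:
    datetime(2011, 2, 29)
    feb29_ok = True
except ValueError:
    feb29_ok = False

print(dates)
print(feb29_ok)
```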
With the index expressed as dates, we can use the plot method to plot the data. It reveals some interesting trends:

# Plot the results
fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);