Multiple Plots
In the previous mission, we explored how a visual representation of data can help us spot patterns more quickly than a table representation of the same data. We learned how to work with the pyplot module, which provides a high-level interface to the matplotlib library, to create and customize a line chart of unemployment data. To look for potential seasonality, we started by creating a line chart of unemployment rates from 1948.
In this mission, we'll dive a bit deeper into matplotlib to learn how to create multiple line charts to help us compare monthly unemployment trends across time. The unemployment dataset contains 2 columns:
DATE: date, always the first of the month. Examples:
  1948-01-01: January 1, 1948.
  1948-02-01: February 1, 1948.
  1948-03-01: March 1, 1948.
  1948-12-01: December 1, 1948.
VALUE: the corresponding unemployment rate, in percent.
Here's what the first 12 rows look like, which reflect the unemployment rate from January 1948 to December 1948:
DATE        VALUE
1948-01-01  3.4
1948-02-01  3.8
1948-03-01  4.0
1948-04-01  3.9
1948-05-01  3.5
1948-06-01  3.6
1948-07-01  3.6
1948-08-01  3.9
1948-09-01  3.8
1948-10-01  3.7
1948-11-01  3.8
1948-12-01  4.0
Let's practice what you learned in the previous mission.
Instructions
Read unrate.csv into a DataFrame and assign it to unrate.
Use pandas.to_datetime to convert the DATE column into a Series of datetime values.
Generate a line chart that visualizes the unemployment rates from 1948:
  x-values should be the first 12 values in the DATE column.
  y-values should be the first 12 values in the VALUE column.
Use pyplot.xticks() to rotate the x-axis tick labels by 90 degrees.
Use pyplot.xlabel() to set the x-axis label to "Month".
Use pyplot.ylabel() to set the y-axis label to "Unemployment Rate".
Use pyplot.title() to set the plot title to "Monthly Unemployment Trends, 1948".
Display the plot.
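The steps above can be sketched as follows. So that the sketch runs on its own, it reads the twelve 1948 rows from the table above through io.StringIO as a stand-in for unrate.csv; with the real file you would call pd.read_csv("unrate.csv") instead.

```python
import io

import matplotlib.pyplot as plt
import pandas as pd

# Stand-in for unrate.csv: the 12 rows for 1948 shown in the table above.
CSV_1948 = """DATE,VALUE
1948-01-01,3.4
1948-02-01,3.8
1948-03-01,4.0
1948-04-01,3.9
1948-05-01,3.5
1948-06-01,3.6
1948-07-01,3.6
1948-08-01,3.9
1948-09-01,3.8
1948-10-01,3.7
1948-11-01,3.8
1948-12-01,4.0
"""

unrate = pd.read_csv(io.StringIO(CSV_1948))  # real data: pd.read_csv("unrate.csv")
# Convert the DATE strings into datetime values.
unrate["DATE"] = pd.to_datetime(unrate["DATE"])

# Plot the first 12 rows (January to December 1948).
plt.plot(unrate["DATE"][:12], unrate["VALUE"][:12])
plt.xticks(rotation=90)  # rotate the x-axis tick labels by 90 degrees
plt.xlabel("Month")
plt.ylabel("Unemployment Rate")
plt.title("Monthly Unemployment Trends, 1948")
plt.show()
```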
When we were working with a single plot, pyplot was storing and updating the state of that single plot, so we could tweak the plot using only the functions in the pyplot module. When we want to work with multiple plots, however, we need to be more explicit about which plot we're changing. This means we need to understand the matplotlib classes that pyplot uses internally to maintain state, so we can interact with them directly. Let's start by looking at what pyplot was automatically storing under the hood when we created a single plot:
a container for all plots was created (returned as a Figure object)
a container for the plot was positioned on a grid (the plot returned as an Axes object)
visual symbols were added to the plot (using the Axes methods)
A figure acts as a container for all of our plots and has methods for customizing the appearance and behavior for the plots within that container. Some examples include changing the overall width and height of the plotting area and the spacing between plots.
We can manually create a figure by calling pyplot.figure() and assigning the result to a variable: fig = plt.figure()
Instead of only calling the pyplot function, we assigned its return value to a variable (fig). A newly created figure doesn't contain any plots yet; once we add an Axes to it, that plot will resemble the empty plot from the previous mission until we give it data. The Axes object acts as its own container for the various components of the plot, such as:
values on the x-axis and y-axis
ticks on the x-axis and y-axis
all visual symbols, such as:
markers
lines
gridlines
While plots are represented using instances of the Axes class, they're also often referred to as subplots in matplotlib. To add a new subplot to an existing figure, use Figure.add_subplot(nrows, ncols, index). The first two arguments describe the grid of plots and the third is this subplot's position in it. The method returns a new Axes object, which we assign to a variable so we can work with that subplot later.
If we want the figure to contain 2 plots, one above the other, we can write ax1 = fig.add_subplot(2, 1, 1) and ax2 = fig.add_subplot(2, 1, 2).
This creates a grid of plots with 2 rows and 1 column. Once we're done adding subplots to the figure, we display everything using plt.show().
Let's create a figure, add subplots to it, and display it.
Instructions
Use plt.figure() to create a figure and assign it to fig.
Use Figure.add_subplot() to create two subplots, one above the other:
  Assign the top Axes object to ax1.
  Assign the bottom Axes object to ax2.
Use plt.show() to display the resulting plot.
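A minimal sketch of this exercise. The print calls are only there to show that a new figure starts with no Axes and that each add_subplot() call adds one.

```python
import matplotlib.pyplot as plt

# Create an empty figure: the container that will hold every subplot.
fig = plt.figure()
print(len(fig.axes))  # 0, a new figure has no Axes yet

# add_subplot(nrows, ncols, index) returns a new Axes object.
ax1 = fig.add_subplot(2, 1, 1)  # top plot in a 2-row, 1-column grid
ax2 = fig.add_subplot(2, 1, 2)  # bottom plot
print(len(fig.axes))  # 2

plt.show()
```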
For each subplot, matplotlib generated a coordinate grid similar to the one we generated in the last mission using the plot() function:
x-axis and y-axis values ranging from 0.0 to 1.0
no gridlines
no data
The main difference from that earlier plot is the range: these empty subplots run from 0.0 to 1.0 instead of from -0.06 to 0.06, a quirk that comes from differences in default properties.
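One quick way to inspect the default range of an empty subplot is to ask the Axes for its view limits. This sketch converts the limits to plain floats only to make the printed output easy to read.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# With no data, the subplot keeps its default view limits.
x_limits = tuple(float(v) for v in ax.get_xlim())
y_limits = tuple(float(v) for v in ax.get_ylim())
print(x_limits, y_limits)  # (0.0, 1.0) (0.0, 1.0)
```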
Now that we have a basic understanding of the important matplotlib classes, we can create multiple plots to compare monthly unemployment trends. Recall that we need to specify the position of each subplot on a grid. In a 2 by 2 subplot layout, positions 1 and 2 fill the top row from left to right, and positions 3 and 4 fill the bottom row.
When the first subplot is created, matplotlib knows to create a grid with 2 rows and 2 columns. As we add each subplot, we specify the plot number we want returned and the corresponding Axes object is created and returned. In matplotlib, the plot number starts at the top left position in the grid (left-most plot on the top row), moves through the remaining positions in that row, then jumps to the left-most plot in the second row, and so forth.
If we set up a grid of 4 subplots but don't create a subplot for every position in the grid, the areas without axes are left blank.
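The numbering scheme and the blank-area behavior can be sketched like this; position 3 is skipped on purpose so that the bottom-left area stays empty.

```python
import matplotlib.pyplot as plt

fig = plt.figure()

# In a 2 x 2 grid, plot numbers run left to right along the top row (1, 2)
# and then continue along the bottom row (3, 4).
ax1 = fig.add_subplot(2, 2, 1)  # top left
ax2 = fig.add_subplot(2, 2, 2)  # top right
ax4 = fig.add_subplot(2, 2, 4)  # bottom right

# No Axes is created for position 3, so the bottom-left area stays blank.
plt.show()
```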
To generate a line chart within an Axes object, we call Axes.plot() and pass in the data we want plotted, for example ax1.plot(unrate['DATE'][:12], unrate['VALUE'][:12]).
Like pyplot.plot(), Axes.plot() accepts array-like objects for these parameters, including lists, NumPy arrays, and pandas Series objects, and it generates a line chart by default from the values passed in. Each time we want to generate a line chart, we call Axes.plot() and pass in the data we want to use in that plot.
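A small sketch of Axes.plot() receiving different array-like inputs. The values are arbitrary placeholders, not unemployment data; they only illustrate that each call adds one line to the same Axes.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Axes.plot(x_values, y_values) draws a line chart on this specific Axes.
ax.plot([1, 2, 3], [3.0, 1.0, 2.0])                         # Python lists
ax.plot(np.array([1, 2, 3]), np.array([2.0, 3.0, 1.0]))     # NumPy arrays
ax.plot(pd.Series([1, 2, 3]), pd.Series([1.0, 2.0, 3.0]))   # pandas Series

# Each call added one line to the same Axes.
print(len(ax.lines))  # 3
plt.show()
```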
Instructions
Create 2 line subplots in a 2 row by 1 column layout:
  In the top subplot, plot the data from 1948.
    For the x-axis, use the first 12 values in the DATE column.
    For the y-axis, use the first 12 values in the VALUE column.
  In the bottom subplot, plot the data from 1949.
    For the x-axis, use the values from index 12 to 23 in the DATE column.
    For the y-axis, use the values from index 12 to 23 in the VALUE column.
Use plt.show() to display all the plots.
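A sketch of this exercise wrapped in a function so it can be reused; the name plot_1948_and_1949 is hypothetical. It assumes unrate has already been loaded from unrate.csv and that DATE has been converted with pd.to_datetime, with one row per month starting in January 1948.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_1948_and_1949(unrate: pd.DataFrame):
    """Plot 1948 (rows 0-11) above 1949 (rows 12-23) in a 2 x 1 grid.

    Assumes a datetime DATE column and a numeric VALUE column.
    """
    fig = plt.figure()
    ax1 = fig.add_subplot(2, 1, 1)
    ax2 = fig.add_subplot(2, 1, 2)
    # Slices stop before the end index, so [12:24] covers index 12 to 23.
    ax1.plot(unrate["DATE"][:12], unrate["VALUE"][:12])
    ax2.plot(unrate["DATE"][12:24], unrate["VALUE"][12:24])
    return fig, ax1, ax2


# Usage, assuming unrate.csv is in the working directory:
# unrate = pd.read_csv("unrate.csv")
# unrate["DATE"] = pd.to_datetime(unrate["DATE"])
# plot_1948_and_1949(unrate)
# plt.show()
```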
Instead of rotating the x-axis tick labels, we can extend the entire plotting area horizontally (for example with plt.figure(figsize=(12, 5))) to make the labels more readable. Because the goal is to look for visual similarities between the lines in the plots, we want the space between the 2 plots to be as small as possible. If we had rotated the labels by 90 degrees instead, like we did in the last mission, we'd need to increase the spacing between the plots to keep them from overlapping. Expanding the plotting area horizontally improves the readability of the x-axis tick labels while keeping the 2 line charts close together.
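Widening the plotting area is done with the figsize parameter of plt.figure(), given as (width, height) in inches. This sketch uses the 12 x 5 size listed in this mission's summary and reads the size back to confirm it.

```python
import matplotlib.pyplot as plt

# figsize is (width, height) in inches; a wide, short figure gives the
# x-axis tick labels room without rotating them.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)

width, height = fig.get_size_inches()
print(width, height)  # 12.0 5.0
```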
If you recall, we generated these 2 line charts because we were interested in looking for any seasonality in the monthly unemployment trends. If you spend some time visually analyzing both line charts, you'll notice that no change in the unemployment trend occurs in the same month in both years.
On this screen, we're going to visualize data from a few more years to see if we find any evidence of seasonality between those years. Because we're going to plot multiple years, we can use a loop so that we don't repeat code unnecessarily. To generate values for the loop, we'll use Python's range() function, which produces a sequence of integers.
We provide an integer argument to range(), and it produces a sequence of integers starting at zero and going up to (but not including) the argument's value. For example, range(3) produces 0, 1, and 2. We can use this in a loop to produce a plot similar to the one on the previous screen: on iteration i, we add subplot number i + 1 and plot the 12 rows starting at index i * 12.
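The index arithmetic behind that loop can be checked in plain Python before any plotting. This sketch assumes, as in unrate.csv, that each year occupies exactly 12 consecutive monthly rows.

```python
# range(5) yields 0, 1, 2, 3, 4.
print(list(range(5)))  # [0, 1, 2, 3, 4]

# Year i starts at row i * 12 and stops before row (i + 1) * 12,
# which matches slices like [0:12] for 1948 and [12:24] for 1949.
bounds = [(i * 12, (i + 1) * 12) for i in range(5)]
print(bounds)  # [(0, 12), (12, 24), (24, 36), (36, 48), (48, 60)]
```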
Let's use this technique to plot five years of data.
Instructions
Set the width of the plotting area to 12 inches and the height to 12 inches.
Generate a grid with 5 rows and 1 column and plot data from the individual years. Start with 1948 in the top subplot and end with 1952 in the bottom subplot.
Use plt.show() to display the plots.
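A sketch of the five-year loop; the function name plot_years_stacked and its n_years parameter are hypothetical. It assumes unrate starts in January 1948 with one row per month, a datetime DATE column, and a numeric VALUE column.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_years_stacked(unrate: pd.DataFrame, n_years: int = 5):
    """One subplot per year, stacked in an n_years x 1 grid."""
    fig = plt.figure(figsize=(12, 12))  # width, height in inches
    axes = []
    for i in range(n_years):
        # Plot numbers start at 1, so iteration i uses position i + 1.
        ax = fig.add_subplot(n_years, 1, i + 1)
        start, end = i * 12, (i + 1) * 12
        subset = unrate[start:end]  # the 12 monthly rows for year 1948 + i
        ax.plot(subset["DATE"], subset["VALUE"])
        axes.append(ax)
    return fig, axes


# Usage, assuming unrate has been loaded and DATE converted:
# plot_years_stacked(unrate)
# plt.show()
```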
By adding more line charts, we can look across more years for seasonal trends. This comes at a cost, unfortunately. We now have to visually scan over more space, which is a limitation that we experienced when scanning the table representation of the same data. If you recall, one of the limitations of the table representation we discussed in the previous mission was the amount of time we'd have to spend scanning the table as the number of rows increased significantly.
We can reduce the visual overhead each additional plot adds by overlaying the line charts in a single subplot. If we remove the year from the x-axis and keep just the month values, we can use the same x-axis values to plot all of the lines. First, we'll explore how to extract just the month values from the DATE column, then we'll dive into generating multiple plots on the same coordinate grid.
To extract the month values from the DATE column and assign them to a new column, we can use the pandas.Series.dt accessor, for example unrate['MONTH'] = unrate['DATE'].dt.month.
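A minimal sketch of the month extraction, using three DATE/VALUE pairs copied from the 1948 table as stand-in data instead of the full file.

```python
import pandas as pd

# Three rows from the 1948 table, with DATE converted to datetime.
unrate = pd.DataFrame(
    {"DATE": pd.to_datetime(["1948-01-01", "1948-02-01", "1948-12-01"]),
     "VALUE": [3.4, 3.8, 4.0]}
)

# .dt exposes datetime properties element-wise; .month is the integer month.
unrate["MONTH"] = unrate["DATE"].dt.month
print(unrate["MONTH"].tolist())  # [1, 2, 12]
```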
Calling pandas.Series.dt.month returns a Series containing the integer value for each month (e.g. 1 for January, 2 for February, etc.). Conceptually, pandas reads the month component of each datetime value in the DATE column, much like the month attribute of a Python datetime object. Let's now move on to generating multiple line charts in the same subplot.
In the last mission, we called pyplot.plot() to generate a single line chart. Under the hood, matplotlib created a figure and a single subplot for this line chart. If we call pyplot.plot() multiple times, matplotlib generates all of the line charts on that single subplot.
If we want to set the dimensions of the plotting area, we can create the figure ourselves first and then plot the data. This works because pyplot checks whether a current figure already exists before plotting, and only creates a new one if it doesn't find one.
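A sketch of both behaviors: the figure we create first is the one plt.plot() draws into, and repeated plt.plot() calls share one subplot. The y-values are arbitrary placeholders.

```python
import matplotlib.pyplot as plt

# Create the figure first to control its size; the plt.plot() calls
# below draw into this figure instead of creating a new one.
fig = plt.figure(figsize=(6, 3))

# Two plt.plot() calls land on the same (single) subplot.
plt.plot([1, 2, 3], [3.0, 3.5, 3.2])
plt.plot([1, 2, 3], [4.0, 3.8, 4.1])

ax = plt.gca()  # the current Axes, i.e. the single subplot
print(len(fig.axes), len(ax.lines))  # 1 2
```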
By default, matplotlib selects a different color for each line. To specify the color ourselves, use the c parameter (an alias for color) when calling plot(), for example plt.plot(unrate[0:12]['MONTH'], unrate[0:12]['VALUE'], c='red').
The matplotlib documentation describes the different ways we can specify colors.
Instructions
Set the plotting area to a width of 6 inches and a height of 3 inches.
Generate 2 line charts in the base subplot, using the MONTH column for the x-axis instead of the DATE column:
  One line chart using data from 1948, with the line color set to "red".
  One line chart using data from 1949, with the line color set to "blue".
Use plt.show() to display the plots.
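A sketch of this exercise as a hypothetical function, plot_1948_vs_1949. It assumes monthly rows starting in January 1948 with a datetime DATE column, and it works on a copy so that adding MONTH doesn't modify the caller's DataFrame.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_1948_vs_1949(unrate: pd.DataFrame):
    """Overlay 1948 and 1949 on one subplot, with MONTH on the x-axis."""
    unrate = unrate.copy()  # avoid adding MONTH to the caller's frame
    unrate["MONTH"] = unrate["DATE"].dt.month
    fig = plt.figure(figsize=(6, 3))  # width, height in inches
    plt.plot(unrate[0:12]["MONTH"], unrate[0:12]["VALUE"], c="red")     # 1948
    plt.plot(unrate[12:24]["MONTH"], unrate[12:24]["VALUE"], c="blue")  # 1949
    return fig


# Usage: plot_1948_vs_1949(unrate); plt.show()
```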
Let's visualize 5 years' worth of unemployment rates on the same subplot.
Instructions
Set the plotting area to a width of 10 inches and a height of 6 inches.
Generate the following plots in the base subplot:
  1948: set the line color to "red"
  1949: set the line color to "blue"
  1950: set the line color to "green"
  1951: set the line color to "orange"
  1952: set the line color to "black"
Use plt.show() to display the plots.
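One way to avoid five near-identical plot() calls is to loop over the colors in year order. This sketch (function name plot_five_years_overlay is hypothetical) assumes monthly rows starting in January 1948 with a datetime DATE column.

```python
import matplotlib.pyplot as plt
import pandas as pd

COLORS = ["red", "blue", "green", "orange", "black"]  # 1948 to 1952


def plot_five_years_overlay(unrate: pd.DataFrame):
    """Overlay 1948-1952 on one subplot, one color per year."""
    months = unrate["DATE"].dt.month
    fig = plt.figure(figsize=(10, 6))
    for i, color in enumerate(COLORS):
        start, end = i * 12, (i + 1) * 12  # rows for year 1948 + i
        plt.plot(months[start:end], unrate["VALUE"][start:end], c=color)
    return fig


# Usage: plot_five_years_overlay(unrate); plt.show()
```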
How colorful! By plotting all of the lines in one coordinate grid, we got a different perspective on the data. The main thing that sticks out is how the blue and green lines span a larger range of y values (4% to 8% for blue and 4% to 7% for green) while the 3 lines below them mostly range only between 3% and 4%. Notice that the previous sentence could only describe the lines by color: because the x-axis now shows only month values, the chart doesn't tell us which line corresponds to which year.
To help remind us which year each line corresponds to, we can add a legend that links each color to the year the line represents.
When we generate each line chart, we need to specify the text label we want each color linked to. The pyplot.plot() function accepts a label parameter, which we use to set the year value, for example plt.plot(unrate[0:12]['MONTH'], unrate[0:12]['VALUE'], c='red', label='1948').
We can create the legend using pyplot.legend and specify its location using the loc parameter, for example plt.legend(loc="upper left").
If we're instead working with multiple subplots, we can create a legend for each subplot by repeating these steps for each one. When we use plt.plot() and plt.legend(), pyplot calls the Axes.plot() and Axes.legend() methods of the current Axes under the hood and passes our parameters along. To create a legend for a specific subplot, we call Axes.legend() on that subplot's Axes object instead.
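A sketch of per-subplot legends with Axes.legend(). The y-values are placeholders; the labels and colors follow the 1948/1949 convention used in this mission.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)

# label= sets the legend text for each line.
ax1.plot([1, 2, 3], [3.0, 3.5, 3.2], c="red", label="1948")
ax2.plot([1, 2, 3], [4.0, 3.8, 4.1], c="blue", label="1949")

# Axes.legend() builds a legend for that subplot only.
ax1.legend(loc="upper left")
ax2.legend(loc="upper left")
plt.show()
```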
Let's now add a legend for the plot we generated in the last screen.
Instructions
Modify the code from the last screen that overlaid 5 plots to include a legend. Use the year value for each line chart as the label.
  E.g. the plot of 1948 data that uses "red" for the line color should be labeled "1948" in the legend.
Place the legend in the "upper left" corner of the plot.
Display the plot using plt.show().
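A sketch of the overlaid chart with a legend. Unlike the positional slicing used so far, it selects each year's rows with a boolean filter on DATE's year, a swapped-in technique that doesn't depend on the rows starting in January 1948. The function name plot_overlay_with_legend is hypothetical.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Line color for each year, as listed in the instructions.
YEAR_COLORS = {1948: "red", 1949: "blue", 1950: "green",
               1951: "orange", 1952: "black"}


def plot_overlay_with_legend(unrate: pd.DataFrame):
    """Overlay one line per year on a single subplot and add a legend.

    Assumes a datetime DATE column and a numeric VALUE column.
    """
    fig = plt.figure(figsize=(10, 6))
    for year, color in YEAR_COLORS.items():
        # Select that year's rows by filtering on the year, not by position.
        subset = unrate[unrate["DATE"].dt.year == year]
        plt.plot(subset["DATE"].dt.month, subset["VALUE"],
                 c=color, label=str(year))  # label becomes the legend text
    plt.legend(loc="upper left")
    return fig


# Usage: plot_overlay_with_legend(unrate); plt.show()
```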
Instead of referring back to the code each time we want to confirm what subset each line corresponds to, we can focus our gaze on the plotting area and use the legend. At the moment, the legend unfortunately covers part of the green line (which represents data from 1950). Since the legend isn't critical to the plot, we should move this outside of the coordinate grid. We'll explore how to do so in a later course because it requires a better understanding of some design principles as well as matplotlib.
Before we wrap up this mission, let's enhance the visualization by adding a title and labels for both axes. To set the title, we use pyplot.title() and pass in a string value, for example plt.title('Unemployment Trend, 1948').
To set the x-axis and y-axis labels, we use pyplot.xlabel() and pyplot.ylabel(). Both of these functions accept string values.
Instructions
Modify the code from the last screen:
  Set the title to "Monthly Unemployment Trends, 1948-1952".
  Set the x-axis label to "Month, Integer".
  Set the y-axis label to "Unemployment Rate, Percent".
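A sketch of the finished chart, combining the colored overlay, the legend, and the new title and axis labels. As before, it assumes monthly rows starting in January 1948 with a datetime DATE column; the function name is hypothetical.

```python
import matplotlib.pyplot as plt
import pandas as pd

COLORS = ["red", "blue", "green", "orange", "black"]  # 1948 to 1952


def plot_unemployment_1948_1952(unrate: pd.DataFrame):
    """Overlaid 1948-1952 chart with a legend, a title, and axis labels."""
    months = unrate["DATE"].dt.month
    fig = plt.figure(figsize=(10, 6))
    for i, color in enumerate(COLORS):
        start, end = i * 12, (i + 1) * 12  # rows for year 1948 + i
        plt.plot(months[start:end], unrate["VALUE"][start:end],
                 c=color, label=str(1948 + i))
    plt.legend(loc="upper left")
    plt.title("Monthly Unemployment Trends, 1948-1952")
    plt.xlabel("Month, Integer")
    plt.ylabel("Unemployment Rate, Percent")
    return fig


# Usage: plot_unemployment_1948_1952(unrate); plt.show()
```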
In this mission, we learned about the important matplotlib building blocks and used them to experiment with creating multiple line charts. In the next mission, we'll explore plots that allow us to visualize discrete data.
Creating a figure using the pyplot module:
fig = plt.figure()
Adding subplots to an existing figure in a grid with 2 rows and 1 column, one above the other:
Returns a new Axes object, which needs to be assigned to a variable:
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)
Generating a line chart within an Axes object:
ax1.plot(unrate['DATE'][:12], unrate['VALUE'][:12])
ax2.plot(unrate['DATE'][12:24], unrate['VALUE'][12:24])
Changing the dimensions of the figure with the figsize parameter (width x height):
fig = plt.figure(figsize=(12, 5))
Specifying the color for a certain line using the c parameter:
plt.plot(unrate[0:12]['MONTH'], unrate[0:12]['VALUE'], c='red')
Creating a legend using the pyplot module and specifying its location:
plt.legend(loc="upper left")
Setting the title for an Axes object:
ax.set_title('Unemployment Trend, 1948')
A figure acts as a container for all of our plots and has methods for customizing the appearance and behavior for the plots within that container.
When we create a single plot, pyplot does the following under the hood:
A container for all plots is created (returned as a Figure object).
A container for the plot is positioned on a grid (the plot is returned as an Axes object).
Visual symbols are added to the plot (using the Axes methods).
For each empty subplot, matplotlib generates a coordinate grid similar to the one we generated using the plot() function:
The x-axis and y-axis values ranging from 0.0 to 1.0.
No gridlines.
No data.