Matplotlib for Python Developers
Plotting data from a CSV file
A common format for exporting and distributing datasets is the Comma-Separated Values (CSV) format. For example, spreadsheet applications can export a worksheet as CSV, some databases also support CSV export, and many datasets are distributed on the Web in this format.
In this example, we'll plot the evolution of the world's population, divided by continent, between 1950 and 2050 (the values for future years are, of course, predictions), using a new type of graph: stacked bars.
Using the data available at http://www.xist.org/earth/pop_continent.aspx (which is based on the official UN data at http://esa.un.org/unpp/index.asp), we have prepared the following CSV file:
Continent,1950,1975,2000,2010,2025,2050
Africa,227270,418765,819462,1033043,1400184,1998466
Asia,1402887,2379374,3698296,4166741,4772523,5231485
Europe,547460,676207,726568,732759,729264,691048
Latin America,167307,323323,521228,588649,669533,729184
Northern America,171615,242360,318654,351659,397522,448464
Oceania,12807,21286,31160,35838,42507,51338
In the first line, we can find the header with a description of what the data in the columns represent. The other lines contain the continent's name and its population (in thousands) for the given years.
There are several ways to parse a CSV file, for example:
- NumPy's loadtxt() (what we are going to use here)
- Matplotlib's mlab.csv2rec() (since removed from Matplotlib)
- The csv module (in the standard library)
but we decided to go with loadtxt() because it's very powerful (and it's what Matplotlib was standardizing on at the time).
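For comparison, the csv module from the standard library can parse the same file. The following is a minimal sketch, assuming the CSV contents shown above are embedded in a string (CSV_TEXT, a name introduced only for this example) instead of being read from disk; it produces the years, the continent names, and a NumPy array with the population values.

```python
import csv
import io

import numpy as np

# In-memory copy of the CSV shown above (assumption: no file on disk)
CSV_TEXT = """Continent,1950,1975,2000,2010,2025,2050
Africa,227270,418765,819462,1033043,1400184,1998466
Asia,1402887,2379374,3698296,4166741,4772523,5231485
Europe,547460,676207,726568,732759,729264,691048
Latin America,167307,323323,521228,588649,669533,729184
Northern America,171615,242360,318654,351659,397522,448464
Oceania,12807,21286,31160,35838,42507,51338
"""

reader = csv.reader(io.StringIO(CSV_TEXT))
header = next(reader)                        # ['Continent', '1950', ...]
years = [int(year) for year in header[1:]]   # skip the 'Continent' label

continents = []
rows = []
for record in reader:
    continents.append(record[0])                        # continent name
    rows.append([int(value) for value in record[1:]])   # thousands of people

data = np.array(rows, dtype=np.int32)   # one row per continent
print(years)
print(continents)
print(data.shape)
```

This works, but every conversion is done by hand; loadtxt() performs it for us through its dtype argument.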
Let's look at how we can plot it then:
# for file opening made easier
from __future__ import with_statement
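We need this because we will use the with statement to read the file, and on Python 2.5 that statement must be enabled explicitly. On Python 2.6 and later (including Python 3), the with statement is always available and this import has no effect.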
# numpy
import numpy as np
NumPy is used to load the CSV and for its useful array data type.
# matplotlib plotting module
import matplotlib.pyplot as plt
# matplotlib colormap module
import matplotlib.cm as cm
# needed for formatting Y axis
from matplotlib.ticker import FuncFormatter
# Matplotlib font manager
import matplotlib.font_manager as font_manager
In addition to the classic pyplot module, we need other Matplotlib submodules:
- cm (color map): Considering the way we're going to prepare the plot, we need to specify the color map of the graphical elements
- FuncFormatter: We will use this to change the way the Y-axis labels are displayed
- font_manager: We want to have a legend with a smaller font, and font_manager allows us to do that
def billions(x, pos):
    """Formatter for Y axis, values are in billions"""
    return '%1.fbn' % (x*1e-6)
This is the function that we will use to format the Y-axis labels. Our data is in thousands. Therefore, by dividing it by one million, we obtain values in the order of billions. The function is called at every label to draw, passing the label value and the position.
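To see what the formatter produces, we can call it by hand, the same way FuncFormatter calls it (value first, then position). The tick values below are chosen only for illustration; note that '%1.f' prints no decimal digits, so the labels are whole billions.

```python
def billions(x, pos):
    """Formatter for Y axis, values are in billions"""
    # x is a tick value in thousands of people; * 1e-6 turns it into billions
    return '%1.fbn' % (x * 1e-6)

# Illustrative tick values (in thousands), passed with their positions
ticks = [0, 1000000, 2000000, 7000000]
labels = [billions(value, pos) for pos, value in enumerate(ticks)]
print(labels)  # ['0bn', '1bn', '2bn', '7bn']
```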
# bar width
width = .8
As said earlier, we will plot bars, and here we define their width.
The following is the parsing code. It is a bit hard to follow (data preparation code usually is the hardest part), but it shows how powerful these tools are.
# open CSV file
with open('population.csv') as f:
The function we're going to use, NumPy's loadtxt(), can receive either a filename or an open file object, as in this case. We open the file ourselves because we have to read the header line separately from the rest of the file and use it to set up the data parsing structures.
    # read the first line, splitting the years
    years = list(map(int, f.readline().split(',')[1:]))
Here we read the first line, the header, and extract the years: we split() it on commas and map the int() function over the resulting list from the second element onwards (the first one is the "Continent" label, a string). On Python 3, map() returns a lazy iterator, so we wrap it in list(), because years is later passed to len() and sliced.
    # we prepare the dtype for extracting data; it's made of:
    # <1 string field> <len(years) integer fields>
    dtype = [('continents', 'S16')] + [('', np.int32)]*len(years)
NumPy is flexible enough to let us define new data types. Here, we create one ad hoc for our data lines: a string (of at most 16 characters) followed by as many integers as there are elements in the years list. Also note how the first field has a name, continents, while the integer fields have none (NumPy assigns them default names): we will need this in a bit.
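The default names can be inspected directly. This short sketch builds the same dtype for the six years in the CSV header and prints its field names and its size in bytes.

```python
import numpy as np

years = [1950, 1975, 2000, 2010, 2025, 2050]
dtype = np.dtype([('continents', 'S16')] + [('', np.int32)] * len(years))

# NumPy replaces each empty field name with 'f' followed by the field index
print(dtype.names)     # ('continents', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6')
# 16 bytes for the string plus 6 * 4 bytes for the int32 fields
print(dtype.itemsize)  # 40
```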
    # we load the file, setting the delimiter and the dtype above
    y = np.loadtxt(f, delimiter=',', dtype=dtype)
With the new data type, we can actually call loadtxt(). Here is the description of the parameters:
- f: This is the file object. Since we have already read the header line from it, loadtxt() starts reading from the second line, so it receives all the data lines and nothing else.
- delimiter: By default, loadtxt() expects values to be separated by whitespace, but since we are parsing a CSV file, the separator is a comma.
- dtype: This is the data type applied to the text we read. By default, loadtxt() converts every value to float.
    # "map" the resulting structure to be easily accessible:
    # the first column (made of strings) is called 'continents'
    # the remaining values are added to the 'data' sub-matrix
    # where the real data are
    y = y.view(np.dtype([('continents', 'S16'),
                         ('data', np.int32, len(years))]))
Here we're using a trick: we view the resulting data structure as made up of two parts, continents and data. This is similar to the dtype we defined earlier, with an important difference: the integer values are now mapped to a single field name, data. As a result, continents is a column with all the continent names, and data is a matrix that contains the yearly values for each row of the file.
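data = y['data']
# on Python 3, 'S16' fields are bytes: decode the names to str for the legend
continents = [name.decode() for name in y['continents']]
We separate the data and the continents into two variables for easier use in the rest of the code.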
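The whole parsing sequence can be tried without the file on disk. The following sketch feeds loadtxt() an in-memory copy of the CSV through io.StringIO (the CSV_TEXT string is an assumption of this example) and uses 'U16' instead of 'S16', so that on Python 3 the continent names come back as str rather than bytes. It is a variation on the code above, shown only to make the header, dtype, and view steps easy to experiment with.

```python
import io

import numpy as np

CSV_TEXT = """Continent,1950,1975,2000,2010,2025,2050
Africa,227270,418765,819462,1033043,1400184,1998466
Asia,1402887,2379374,3698296,4166741,4772523,5231485
Europe,547460,676207,726568,732759,729264,691048
Latin America,167307,323323,521228,588649,669533,729184
Northern America,171615,242360,318654,351659,397522,448464
Oceania,12807,21286,31160,35838,42507,51338
"""

f = io.StringIO(CSV_TEXT)
# header line: keep the years, skip the 'Continent' label
years = list(map(int, f.readline().split(',')[1:]))

# one unicode string field followed by one unnamed int32 field per year
dtype = [('continents', 'U16')] + [('', np.int32)] * len(years)
y = np.loadtxt(f, delimiter=',', dtype=dtype)  # parses the remaining lines

# same bytes, new layout: the integers become one (len(years),) sub-array
y = y.view(np.dtype([('continents', 'U16'),
                     ('data', np.int32, len(years))]))

data = y['data']              # 2D array: one row per continent
continents = y['continents']  # 1D array of names
print(data.shape)
print(continents[0], data[0, 0])
```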
# prepare the bottom array
bottom = np.zeros(len(years))
We prepare an array of zeros of the same length as years. As said earlier, we plot stacked bars, so each dataset is plotted on top of the previous ones, and we need to know where the bars below end. The bottom array keeps track of this, containing the height of the bars already plotted.
# for each line in data
for i in range(len(data)):
Now that we have our information in data, we can loop over it.
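    # create the bars for each element, on top of the previous bars;
    # align='edge' keeps the left-edge placement that the X ticks below
    # assume (Matplotlib 2.0 and later center bars by default)
    bt = plt.bar(range(len(data[i])), data[i], width=width,
                 color=cm.hsv(32*i), label=continents[i],
                 bottom=bottom, align='edge')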
We then create the stacked bars. Some important notes:
- We select the i-th row of data and plot one bar for each of its elements (data[i]), with the chosen width.
- As the bars are generated in different iterations of the loop, they could otherwise all end up with the same color. To avoid this, we use a color map (in this case hsv), selecting a different color at each iteration, so the sub-bars will have different colors.
- We label each set of bars with the corresponding continent's name (used by the legend).
- As we have said, they are stacked bars: every iteration adds one piece of the overall bars. To do so, we need to know where to start drawing each bar from (its lower limit), and bottom provides this: it contains the value at which to start drawing the current bar.
    # update the bottom array
    bottom += data[i]
We update the bottom array. By adding the current data line, we know what the bottom line will be to plot the next bars on top of it.
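The bookkeeping done by bottom can be checked on a small scale. The following sketch keeps only the 1950 and 2010 columns of the CSV data, runs the same update loop, and compares the starting points with a vectorized alternative based on np.cumsum() (an equivalent formulation, not what the example above uses). The final value of bottom is the total height of each bar, which is the world population for that year, in thousands.

```python
import numpy as np

# 1950 and 2010 columns, rows in CSV order:
# Africa, Asia, Europe, Latin America, Northern America, Oceania
data = np.array([[227270, 1033043],
                 [1402887, 4166741],
                 [547460, 732759],
                 [167307, 588649],
                 [171615, 351659],
                 [12807, 35838]])

bottom = np.zeros(data.shape[1])
bottoms = []
for row in data:
    bottoms.append(bottom.copy())  # where this continent's bar segment starts
    bottom += row                  # the next continent starts on top of it

# Vectorized equivalent: each start is the cumulative sum of the rows above
# (a row of zeros for the first continent)
cumulative = np.vstack([np.zeros(data.shape[1]),
                        np.cumsum(data, axis=0)[:-1]])
print(np.array_equal(np.array(bottoms), cumulative))
print(bottom)  # total population per year, in thousands
```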
# label the X ticks with years
plt.xticks(np.arange(len(years))+width/2,
[int(year) for year in years])
We then add the tick labels, the years, right in the middle of each bar.
# some information on the plot
plt.xlabel('Years')
plt.ylabel('Population (in billions)')
plt.title('World Population: 1950 - 2050 (predictions)')
We add axis labels and a title to the graph.
# draw a legend, with a smaller font
plt.legend(loc='upper left',
prop=font_manager.FontProperties(size=7))
We now draw a legend in the upper-left position with a small font (to better fit the empty space).
# apply the custom function as Y axis formatter
plt.gca().yaxis.set_major_formatter(FuncFormatter(billions))
# display the figure
plt.show()
Finally, we change the Y-axis label formatter to use the custom formatting function that we defined earlier, and display the figure.
The result is the next screenshot where we can see the composition of the world population divided by continents:
In the preceding screenshot, the whole bar represents the total world population, and the sections of each bar show how much each continent contributes to it. Also observe how the hsv color map works: from bottom to top, we have represented Africa in red, Asia in orange, Europe in light green, Latin America in green, Northern America in light blue, and Oceania in blue (barely visible at the top of the bars).
Plotting extrapolated data using curve fitting
While plotting the CSV values, we have seen that there were some columns representing predictions of the world population in the coming years. We'd like to show how to obtain such predictions using the mathematical process of extrapolation with the help of curve fitting.
Curve fitting is the process of constructing a curve (a mathematical function) that best fits a series of data points.
This process is related to two other concepts:
- interpolation: A method of constructing new data points within the range of a known set of points
- extrapolation: A method of constructing new data points outside a known set of points
The results of extrapolation are subject to a greater degree of uncertainty and depend heavily on the fitting function that is used.
So it works this way:
- First, a known set of measurements is passed to the curve fitting procedure, which computes a function that approximates these values
- With this function, we can compute additional values that are not present in the original dataset
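The two steps can be sketched with NumPy alone. In this minimal example, the known measurements are samples of the quadratic 2*x**2 - 3*x + 1, an assumption made so that the results are easy to check: np.polyfit() computes the approximating polynomial, evaluating it at x = 2.5 is interpolation, and evaluating it at x = 6 is extrapolation.

```python
import numpy as np

# Known measurements: samples of the (assumed) quadratic 2*x**2 - 3*x + 1
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = 2 * x**2 - 3 * x + 1

# Step 1: the curve fitting procedure computes an approximating function
coeffs = np.polyfit(x, y, 2)   # coefficients, highest power first
fit = np.poly1d(coeffs)        # callable polynomial

# Step 2: compute values that are not in the original dataset
inside = fit(2.5)    # interpolation: 2.5 lies within [0, 4]
outside = fit(6.0)   # extrapolation: 6.0 lies outside [0, 4]
print(coeffs)        # close to 2, -3, 1
print(inside, outside)  # close to 6 and 55
```

Because the data is noise-free and the fitting function has the right form, both values are exact here; with real measurements, the extrapolated value is the one most sensitive to the choice of fitting function.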
Let's first approach curve fitting with a simple example:
# Numpy and Matplotlib
import numpy as np
import matplotlib.pyplot as plt
These are the classic imports.
# the known points set
data = [[2,2],[5,0],[9,5],[11,4],[12,7],[13,11],[17,12]]
This is the data we will use for curve fitting: points on a plane, each with an X and a Y component.
# we extract the X and Y components from previous points
x, y = zip(*data)
We aggregate the X and Y components in two distinct lists.
# plot the data points with a black cross
plt.plot(x, y, 'kx')
Then we plot the original dataset as black crosses on the Matplotlib figure.
# we want a bit more data and more fine grained for
# the fitting functions
x2 = np.arange(min(x)-1, max(x)+1, .01)
We prepare a new array for the X values because we want a wider range (one unit to the left and one to the right of the original points) and a finer grain, so that the fitting functions are plotted smoothly.
# lines styles for the polynomials
styles = [':', '-.', '--']
To differentiate better between the polynomial lines, we now define their styles list.
# getting style and count one at time
for d, style in enumerate(styles):
Then we loop over that list, also keeping track of each item's index.
    # degree of the polynomial
    deg = d + 1
We define the actual polynomial degree.
    # calculate the coefficients of the fitting polynomial
    c = np.polyfit(x, y, deg)
Then compute the coefficients of the fitting polynomial whose general format is:
c[0]*x**deg + c[1]*x**(deg - 1) + ... + c[deg]
    # we evaluate the fitting function against x2
    y2 = np.polyval(c, x2)
Here, we generate the new values by evaluating the fitting polynomial against the x2 array.
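To see that np.polyval() evaluates exactly the general form given above, the following sketch compares it with the same sum written out term by term, for a polynomial chosen only for the example, x**2 + 2*x + 3.

```python
import numpy as np

# Coefficients in NumPy's order: c[0] multiplies the highest power
c = [1.0, 2.0, 3.0]       # x**2 + 2*x + 3
deg = len(c) - 1

x2 = np.array([0.0, 1.0, 2.0])

# c[0]*x**deg + c[1]*x**(deg - 1) + ... + c[deg], written out explicitly
explicit = sum(c[k] * x2 ** (deg - k) for k in range(deg + 1))

print(np.polyval(c, x2))  # values 3, 6 and 11
print(explicit)           # the same values
```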
    # and then we plot it
    plt.plot(x2, y2, label="deg=%d" % deg, linestyle=style)
Then we plot the resulting function, adding a label that indicates the degree of the polynomial and using a different style for each line.
# show the legend
plt.legend(loc='upper left')
# display the figure
plt.show()
We then show the legend, and the final result is shown in the next screenshot:
Here, the polynomial with degree=1 is drawn as a dotted blue line, the one with degree=2 is a dash-dot green line, and the one with degree=3 is a dashed red line.
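We can see that the higher the degree, the better the function fits the known data points. A closer fit on those points says nothing, however, about how the polynomial behaves outside them, which is exactly where extrapolation takes place.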
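The improvement in fit can be measured. The sketch below fits the same points with degrees 1, 2, and 3 and computes the sum of squared errors on the known points: since any polynomial of a given degree is also a candidate of every higher degree, the least-squares error can only stay the same or decrease as the degree grows.

```python
import numpy as np

# Same known points as in the curve fitting example above
data = [[2, 2], [5, 0], [9, 5], [11, 4], [12, 7], [13, 11], [17, 12]]
x, y = (np.array(values, dtype=float) for values in zip(*data))

sse = {}
for deg in (1, 2, 3):
    c = np.polyfit(x, y, deg)
    residuals = y - np.polyval(c, x)          # errors on the known points
    sse[deg] = float(np.sum(residuals ** 2))  # sum of squared errors

for deg, error in sse.items():
    print('deg=%d: sum of squared errors = %.3f' % (deg, error))
```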
Let's now go back to our main goal: providing an extrapolation for the population data. First, a note: we take the 2010 values as real data rather than predictions (we are quite close to that year); otherwise we would have too few values to create a realistic extrapolation.
Let's see the code:
# for file opening made easier
from __future__ import with_statement
# numpy
import numpy as np
# matplotlib plotting module
import matplotlib.pyplot as plt
# matplotlib colormap module
import matplotlib.cm as cm
# Matplotlib font manager
import matplotlib.font_manager as font_manager
# bar width
width = .8
# open CSV file
with open('population.csv') as f:
    # read the first line, splitting the years
    # (list() is needed on Python 3, where map() is lazy)
    years = list(map(int, f.readline().split(',')[1:]))
    # we prepare the dtype for extracting data; it's made of:
    # <1 string field> <6 integer fields>
    dtype = [('continents', 'S16')] + [('', np.int32)]*len(years)
    # we load the file, setting the delimiter and the dtype above
    y = np.loadtxt(f, delimiter=',', dtype=dtype)
    # "map" the resulting structure to be easily accessible:
    # the first column (made of strings) is called 'continents'
    # the remaining values are added to the 'data' sub-matrix
    # where the real data are
    y = y.view(np.dtype([('continents', 'S16'),
                         ('data', np.int32, len(years))]))
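# extract fields
data = y['data']
# on Python 3, 'S16' fields are bytes: decode the names to str for the legend
continents = [name.decode() for name in y['continents']]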
This is the same code that is used for the CSV example (reported here for completeness).
x = years[:-2]
x2 = years[-2:]
We divide the years into two groups: up to 2010, and after 2010. In practice, this means splitting off the last two elements of the years list.
What we are going to do here is prepare the plot in two phases:
- First, we plot the data we consider certain (up to 2010)
- After this, we plot the UN predictions next to our extrapolations
# prepare the bottom array
b1 = np.zeros(len(years)-2)
We prepare the array (made of zeros) for the bottom argument of bar().
# for each line in data
for i in range(len(data)):
    # select all the data except the last 2 values
    d = data[i][:-2]
For each data line, we extract the information we need, so we remove the last two values.
    # create bars for each element, on top of the previous bars
    # (align='edge': Matplotlib 2.0 and later center bars by default)
    bt = plt.bar(range(len(d)), d, width=width,
                 color=cm.hsv(32*(i)), label=continents[i],
                 bottom=b1, align='edge')
    # update the bottom array
    b1 += d
Then we plot the bar, and update the bottom array.
# prepare the bottom array
b2_1, b2_2 = np.zeros(2), np.zeros(2)
We need two arrays because we will display two bars for each of these years: one from the CSV and the other from our fitting function.
# for each line in data
for i in range(len(data)):
    # extract the last 2 values
    d = data[i][-2:]
Again, for each line in the data matrix, we extract the last two values that are needed to plot the bar for CSV.
    # select the data to compute the fitting function
    y = data[i][:-2]
We also select the other values, which are needed to compute the fitting polynomial.
    # use a polynomial of degree 3
    c = np.polyfit(x, y, 3)
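Here, we use a polynomial of degree 3: with only four known points (1950, 1975, 2000, and 2010), a degree-3 polynomial already passes through all of them, so there is no need for higher degrees.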
    # create a function out of those coefficients
    p = np.poly1d(c)
np.poly1d() constructs a polynomial object from the coefficients we pass as a parameter; the object can then be called like a function.
    # compute p on x2 values (we need integers, so the map;
    # list() is needed on Python 3, where map() is lazy)
    y2 = list(map(int, p(x2)))
We use the polynomial defined earlier to compute its values for x2, and convert the results to integers so that they match the integer population data from the CSV.
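Applied to a single row, the fitting step looks like the following sketch. It uses the Africa row of the CSV, takes 1950 to 2010 as known data, fits a degree-3 polynomial with np.polyfit(), wraps it with np.poly1d(), and evaluates it at 2025 and 2050. Since four points determine a cubic, the polynomial reproduces the known values (up to floating-point error); the extrapolated values are printed next to the UN predictions for comparison.

```python
import numpy as np

years = [1950, 1975, 2000, 2010, 2025, 2050]
africa = np.array([227270, 418765, 819462, 1033043, 1400184, 1998466])

x, x2 = years[:-2], years[-2:]   # known years / years to extrapolate
y = africa[:-2]                  # values treated as real data

c = np.polyfit(x, y, 3)          # 4 points, degree 3
p = np.poly1d(c)                 # callable polynomial object

y2 = [int(v) for v in p(x2)]     # same result as list(map(int, p(x2)))
print(y2)                        # our extrapolation for 2025 and 2050
print(africa[-2:])               # UN predictions for the same years
```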
    # create bars for each element, on top of the previous bars
    # (align='edge': Matplotlib 2.0 and later center bars by default)
    bt = plt.bar(len(b1)+np.arange(len(d)), d, width=width/2,
                 color=cm.hsv(32*(i)), bottom=b2_1, align='edge')
We draw a bar for the data from the CSV. Note how the width is half of that of the other bars. This is because in the same width we will draw the two sets of bars for a better visual comparison.
    # create the bars for the extrapolated values
    bt = plt.bar(len(b1)+np.arange(len(d))+width/2, y2,
                 width=width/2, color=cm.bone(32*(i+2)),
                 bottom=b2_2, align='edge')
Here, we plot the bars for the extrapolated values right next to the CSV ones, using the grayscale bone color map so that the two datasets are even better separated.
    # update the bottom array
    b2_1 += d
    b2_2 += y2
We update both the bottom arrays.
# label the X ticks with years
plt.xticks(np.arange(len(years))+width/2,
[int(year) for year in years])
We add the years as ticks for the X-axis.
# draw a legend, with a smaller font
plt.legend(loc='upper left',
           prop=font_manager.FontProperties(size=7))
# display the figure
plt.show()
To avoid a very large legend, we used labels only for the data from the CSV, skipping the extrapolated values; we believe it's clear what they refer to. Here is the screenshot that is displayed when executing this example:
The conclusion we can draw from this is that the United Nations uses a different function to prepare its predictions, especially because it has a continuous set of data and can also take other environmental circumstances into account.
Tools using Matplotlib
Thanks to its easy and powerful API, Matplotlib is also used inside other programs and tools that need plotting. We are about to present two of these tools:
- NetworkX
- Mpmath
NetworkX
NetworkX (http://networkx.lanl.gov/) is a Python module that contains tools for creating and manipulating (complex) networks, also known as graphs.
A graph is defined as a set of nodes and edges, where each edge is associated with two nodes. NetworkX also makes it possible to associate properties with each node and edge.
NetworkX is not primarily a graph drawing package but, in collaboration with Matplotlib (and also with Graphviz), it's able to show the graph we're working on.
In the following example, we will show how to create a random graph and draw it in a circular shape.
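# matplotlib
import matplotlib.pyplot as plt
# numpy, used to compute the node colors
import numpy as np
# networkx module
import networkx as nx
In addition to pyplot, we import the networkx module and NumPy, which we use to compute the node colors.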
# prepare a random graph with n nodes and m edges
n = 16
m = 60
G = nx.gnm_random_graph(n, m)
Here, we set up a graph with 16 nodes and 60 edges, chosen randomly from all the graphs with these characteristics. The returned graph is undirected: edges simply connect two nodes, without any direction information (from node A to node B or vice versa).
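The G(n, m) model behind gnm_random_graph() picks m edges uniformly at random among all possible node pairs. The following plain-Python sketch illustrates the idea; gnm_edges() is a hypothetical helper written for this example, not NetworkX's implementation, and it returns only an edge list rather than a graph object.

```python
import itertools
import random


def gnm_edges(n, m, seed=None):
    """Pick m distinct undirected edges uniformly among all n*(n-1)/2 pairs.

    Hypothetical helper illustrating the G(n, m) model.
    """
    rng = random.Random(seed)
    # every possible undirected edge, stored once as (u, v) with u < v
    all_pairs = list(itertools.combinations(range(n), 2))
    return rng.sample(all_pairs, m)   # m distinct pairs, uniformly at random


edges = gnm_edges(16, 60, seed=42)   # 60 of the 120 possible edges
print(len(edges))
```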
# prepare a circular layout of nodes
pos = nx.circular_layout(G)
Then we use a node positioning algorithm that prepares a circular layout for the nodes of our graph; the returned variable pos is a dictionary mapping each node to its (x, y) position on the circle.
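The idea behind a circular layout is simple: the k-th of n nodes is placed at angle 2*pi*k/n on a circle. The sketch below (circular_positions() is a hypothetical helper) builds a dictionary shaped like the pos returned by NetworkX; the library's own function may differ in details such as the starting angle and scaling, and it stores NumPy arrays rather than tuples.

```python
import math


def circular_positions(nodes):
    """Place nodes evenly on the unit circle, keyed by node (hypothetical)."""
    n = len(nodes)
    return {node: (math.cos(2 * math.pi * k / n),
                   math.sin(2 * math.pi * k / n))
            for k, node in enumerate(nodes)}


pos = circular_positions(list(range(16)))
print(pos[0])   # (1.0, 0.0): the first node sits at angle 0
```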
# define the color to select from the color map
# as n numbers evenly spaced between color map limits
node_color = list(map(int, np.linspace(0, 255, n)))
We want to give a nice coloring to our nodes, so we will use a particular color map, but before that we have to identify what colors of the color map would be assigned to each node. We do this by selecting 16 numbers evenly spaced in the 256 available colors in the color map. We now have a progression of numbers that will result in a nice fading effect in the nodes' colors.
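The color indices can be checked directly: with n = 16, np.linspace(0, 255, n) produces values 255/15 = 17 apart, so the nodes get indices 0, 17, 34, and so on up to 255. This sketch uses a list comprehension, which gives the same result as list(map(int, ...)).

```python
import numpy as np

n = 16
# 16 evenly spaced indices over the 256 entries (0..255) of a color map
node_color = [int(v) for v in np.linspace(0, 255, n)]
print(node_color)   # 0, 17, 34, ..., 255
```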
# draw the nodes, specifying the color map and the list of color
nx.draw_networkx_nodes(G, pos,
node_color=node_color, cmap=plt.cm.hsv)
We start drawing the graph from the nodes. We pass the graph object, the position pos to draw nodes in a circular layout, the color map, and the list of colors to be assigned to the nodes.
# add the labels inside the nodes
nx.draw_networkx_labels(G, pos)
We then draw the labels for the nodes: the numbers identifying each node, plotted inside it.
# draw the edges, using alpha parameter to make them lighter
nx.draw_networkx_edges(G, pos, alpha=0.4)
Finally, we draw the edges between nodes. We also specify the alpha parameter so that they are a little lighter and don't just appear as a complicated web of lines.
# turn off axis elements
plt.axis('off')
# display the figure
plt.show()
We then remove the Matplotlib axis lines and labels. The result is as shown in the next screenshot where the nodes' colors are distributed across the whole color spectrum:
We advise you to look at the examples available on the NetworkX website: if you like this kind of thing, you will certainly enjoy them.
Mpmath
mpmath (http://code.google.com/p/mpmath/) is a mathematical library, written in pure Python, for multiprecision floating-point arithmetic: every calculation done with mpmath can use an arbitrarily high number of digits of precision. This is extremely important in fields such as numerical simulation and analysis.
It also contains a large number of mathematical functions and constants, and a library of tools commonly needed in mathematical applications.
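Precision in mpmath is controlled globally through mp.dps, the number of significant decimal digits. A minimal sketch, assuming mpmath is installed; note that the string conversion also follows the current precision, so each string is taken while its precision is active.

```python
from mpmath import mp

mp.dps = 50                   # work with 50 significant decimal digits
root2_50 = str(mp.sqrt(2))    # converted while the precision is 50 digits
print(root2_50)

mp.dps = 15                   # back to the default, about double precision
root2_15 = str(mp.sqrt(2))
print(root2_15)
```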
In conjunction with Matplotlib, mpmath provides a convenient plotting interface to display a function graphically.
It is extremely easy to plot with mpmath and Matplotlib:
In [1]: import mpmath as mp
In [2]: mp.plot(mp.sin, [-6, 6])
In this example, the mpmath plot() function takes the function to plot and the interval over which to draw it.
When we run this code, the following window pops up:
We can also plot multiple functions at a time and define our own functions too:
In [1]: import mpmath as mp
In [2]: mp.plot([mp.sqrt, lambda x: -0.1*x**3 + x-0.5], [-3, 3])
On executing the preceding code snippet, we get the following screenshot, where we have plotted the square root (in blue, upper part) and the function we defined (in red, lower part):
To plot more functions, simply provide a list of them to plot(). To define a new function, we use a lambda expression.
Note how the square root plot is drawn with full lines for positive values of X, while it's dotted in the negative part. This is because for negative X the result is a complex number: mpmath represents the real part with dashes and the imaginary part with dots.
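The complex result for negative arguments is easy to check: mpmath's sqrt() returns a complex value, while the standard library's math.sqrt() raises an error. A minimal sketch, assuming mpmath is installed:

```python
import math

from mpmath import mp

z = mp.sqrt(-4)            # complex result: real part 0, imaginary part 2
print(z.real, z.imag)

try:
    math.sqrt(-4)          # the standard library refuses negative inputs
except ValueError as err:
    print('math.sqrt raised ValueError:', err)
    math_failed = True
else:
    math_failed = False
```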
Summary
In this article, we have seen several examples of real-world Matplotlib usage, including:
- How to plot data read from a database
- How to plot data extracted from a parsed Wikipedia article
- How to plot data from parsing an Apache log file
- How to plot data from a CSV file
- How to plot extrapolated data using a curve fitting polynomial
- How to plot using third-party tools such as NetworkX and mpmath
We hope these practical examples have increased your interest in exploring Matplotlib, if you haven't already explored it!
About the Author:
Sandro Tosi
Sandro Tosi is a Debian Developer, Open Source evangelist, and Python enthusiast. After completing a B.Sc. in Computer Science from the University of Firenze, he worked as a consultant for an energy multinational as System Analyst and EAI Architect, and now works as System Engineer for one of the biggest and most innovative Italian Internet companies.