Lecture 10
matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python.
Why do we usually import only pyplot then?
Matplotlib is the whole package; matplotlib.pyplot is a module in matplotlib; and pylab is a module that gets installed alongside matplotlib.
Pyplot provides the state-machine interface to the underlying object-oriented plotting library. The state-machine implicitly and automatically creates figures and axes to achieve the desired plot.
Figure - The entire plot (including subplots)
Axes - Subplot attached to a figure, contains the region for plotting data and x & y axis
Axis - Set the scale and limits, generate ticks and ticklabels
Artist - Everything visible on a figure: text, lines, axis, axes, etc.
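The hierarchy above can be inspected directly. The following is a minimal sketch (the variable names and toy data are illustrative, not from the lecture) that creates one figure with one subplot and checks which class each piece belongs to.

```python
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure

# One Figure containing a single Axes (subplot)
fig, ax = plt.subplots()
(line,) = ax.plot([0, 1, 2], [0, 1, 4])

# The Axes owns an x and a y Axis object, which handle scale, limits and ticks
is_figure = isinstance(fig, Figure)
is_axes = isinstance(ax, Axes)
has_axis = isinstance(ax.xaxis, Axis) and isinstance(ax.yaxis, Axis)

# Everything that gets drawn, including the containers themselves, is an Artist
all_artists = all(isinstance(obj, Artist) for obj in (fig, ax, ax.xaxis, line))

print(is_figure, is_axes, has_axis, all_artists)
```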
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2*np.pi, 30)
y1 = np.sin(x)
y2 = np.cos(x)
fig, (ax1, ax2) = plt.subplots(
2, 1, figsize=(6, 6)
)
fig.suptitle("Main title")
ax1.plot(x, y1, "--b", label="sin(x)")
ax1.set_title("subplot 1")
ax1.legend()
ax2.plot(x, y2, ".-r", label="cos(x)")
ax2.set_title("subplot 2")
ax2.legend()
x = np.linspace(0, 2*np.pi, 30)
y1 = np.sin(x)
y2 = np.cos(x)
plt.figure(figsize=(6, 6))
plt.suptitle("Main title")
plt.subplot(211)
plt.plot(x, y1, "--b", label="sin(x)")
plt.title("subplot 1")
plt.legend()
plt.subplot(2,1,2)
plt.plot(x, y2, ".-r", label="cos(x)")
plt.title("subplot 2")
plt.legend()
plt.show()
x = np.linspace(-2, 2, 101)
fig, axs = plt.subplots(2, 2, figsize=(5, 5))
fig.suptitle("More subplots")
axs[0,0].plot(x, x, "b", label="linear")
axs[0,1].plot(x, x**2, "r", label="quadratic")
axs[1,0].plot(x, x**3, "g", label="cubic")
axs[1,1].plot(x, x**4, "c", label="quartic")
# axs is a 2x2 array of Axes; .flat iterates over all of them
for ax in axs.flat:
    ax.legend()
x = np.linspace(-2, 2, 101)
fig, axd = plt.subplot_mosaic(
[['upleft', 'right'],
['lowleft', 'right']],
figsize=(5, 5)
)
axd['upleft' ].plot(x, x, "b", label="linear")
axd['lowleft'].plot(x, x**2, "r", label="quadratic")
axd['right' ].plot(x, x**3, "g", label="cubic")
axd['upleft'].set_title("Linear")
axd['lowleft'].set_title("Quadratic")
axd['right'].set_title("Cubic")
For quick formatting of plots (scatter and line), format strings are a useful shorthand; they generally take the form '[marker][line][color]',
character | shape |
---|---|
. | point |
, | pixel |
o | circle |
v | triangle down |
^ | triangle up |
< | triangle left |
> | triangle right |
… | + more |

character | line style |
---|---|
- | solid |
-- | dashed |
-. | dash-dot |
: | dotted |

character | color |
---|---|
b | blue |
g | green |
r | red |
c | cyan |
m | magenta |
y | yellow |
k | black |
w | white |
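To make the shorthand concrete, the sketch below draws one line with a format string and a second line with the equivalent keyword arguments, then compares the resulting styles. The data values are arbitrary placeholders.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import same_color

fig, ax = plt.subplots()
x = [0, 1, 2, 3]
y = [0, 1, 4, 9]

# Shorthand: circle markers, dashed line, red
(short,) = ax.plot(x, y, "o--r")

# The same styling spelled out with keyword arguments
(explicit,) = ax.plot(
    x, [v + 1 for v in y],
    marker="o", linestyle="--", color="r"
)

same_style = (
    short.get_marker() == explicit.get_marker()
    and short.get_linestyle() == explicit.get_linestyle()
    and same_color(short.get_color(), explicit.get_color())
)
print(same_style)
```

The keyword form is longer but easier to read back later, so the format string is probably best kept for quick exploratory plots.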
Beyond creating plots for arrays (and lists), addressable objects like dicts and DataFrames can be used via the data argument,
np.random.seed(19680801)
d = {'x': np.arange(50),
'color': np.random.randint(0, 50, 50),
'size': np.abs(np.random.randn(50)) * 100}
d['y'] = d['x'] + 10 * np.random.randn(50)
plt.figure(figsize=(6, 3))
plt.scatter(
'x', 'y', c='color', s='size',
data=d
)
plt.xlabel("x-axis")
plt.ylabel("y-axis")
plt.show()
To fix the clipping of the axis labels we can use the “constrained” layout, which adjusts the spacing automatically,
np.random.seed(19680801)
d = {'x': np.arange(50),
'color': np.random.randint(0, 50, 50),
'size': np.abs(np.random.randn(50)) * 100}
d['y'] = d['x'] + 10 * np.random.randn(50)
plt.figure(
figsize=(6, 3),
layout="constrained"
)
plt.scatter(
'x', 'y', c='color', s='size',
data=d
)
plt.xlabel("x-axis")
plt.ylabel("y-axis")
plt.show()
Data can also come from DataFrame or Series objects,
import pandas as pd

df = pd.DataFrame({
"x": np.random.normal(size=10000)
}).assign(
y = lambda d: np.random.normal(0.75*d.x, np.sqrt(1-0.75**2), size=10000)
)
fig, ax = plt.subplots(figsize=(5,5))
ax.scatter('x', 'y', c='k', data=df, alpha=0.1, s=0.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title("Bivariate normal ($\\rho=0.75$)")
Series objects can also be plotted directly; the index is used for the x-axis values,
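As a minimal sketch of this behaviour, the example below passes a Series straight to ax.plot() and lets matplotlib pick up the x values from its index. The yearly values are hypothetical placeholders.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical data: five values indexed by year
s = pd.Series(
    [1.0, 3.0, 2.0, 5.0, 4.0],
    index=[2019, 2020, 2021, 2022, 2023]
)

fig, ax = plt.subplots(figsize=(5, 3), layout="constrained")
(line,) = ax.plot(s)      # x values are taken from s.index
ax.set_xticks(s.index)    # one tick per index entry
ax.set_xlabel("year")
```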
Axis scales can be changed via plt.xscale(), plt.yscale(), ax.set_xscale(), or ax.set_yscale(); supported values include “linear”, “log”, “symlog”, and “logit”.
y = np.sort( np.random.sample(size=1000) )
x = np.arange(len(y))
plt.figure(layout="constrained")
scales = ['linear', 'log', 'symlog', 'logit']
for i, scale in enumerate(scales):
    plt.subplot(2, 2, i+1)
    plt.plot(x, y)
    plt.grid(True)
    if scale == 'symlog':
        # symlog is linear within +/- linthresh of zero
        plt.yscale(scale, linthresh=0.01)
    else:
        plt.yscale(scale)
    plt.title(scale)
plt.show()
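The same scale changes can be made through the object-oriented interface with ax.set_xscale() and ax.set_yscale(). The sketch below uses made-up data (y = x²) and compares a linear panel with a log-log panel.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 101)
y = x ** 2

fig, (ax1, ax2) = plt.subplots(
    1, 2, figsize=(6, 3), layout="constrained"
)

# Default linear scale on both axes
ax1.plot(x, y)
ax1.set_title("linear")

# Log scale on both axes: a power law appears as a straight line
ax2.plot(x, y)
ax2.set_xscale("log")
ax2.set_yscale("log")
ax2.set_title("log-log")
```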
df = pd.DataFrame({
"cat": ["A", "B", "C", "D", "E"],
"value": np.exp(range(5))
})
plt.figure(figsize=(4, 6), layout="constrained")
plt.subplot(321)
plt.scatter("cat", "value", data=df)
plt.subplot(322)
plt.scatter("value", "cat", data=df)
plt.subplot(323)
plt.plot("cat", "value", data=df)
plt.subplot(324)
plt.plot("value", "cat", data=df)
plt.subplot(325)
b = plt.bar("cat", "value", data=df)
plt.subplot(326)
b = plt.bar("value", "cat", data=df)
plt.show()
df = pd.DataFrame({
"x1": np.random.normal(size=100),
"x2": np.random.normal(1,2, size=100)
})
plt.figure(figsize=(4, 6), layout="constrained")
plt.subplot(311)
h = plt.hist("x1", bins=10, data=df, alpha=0.5)
h = plt.hist("x2", bins=10, data=df, alpha=0.5)
plt.subplot(312)
h = plt.hist(df, alpha=0.5)
plt.subplot(313)
h = plt.hist(df, stacked=True, alpha=0.5)
plt.show()
df = pd.DataFrame({
"x1": np.random.normal(size=100),
"x2": np.random.normal(1,2, size=100),
"x3": np.random.normal(-1,3, size=100)
}).melt()
df
variable value
0 x1 0.085670
1 x1 1.660256
2 x1 1.596326
3 x1 -1.167331
4 x1 0.221311
.. ... ...
295 x3 -0.822684
296 x3 2.081603
297 x3 2.082767
298 x3 -0.046562
299 x3 0.373482
[300 rows x 2 columns]
To the best of your ability recreate the following plot,
Both Series and DataFrame objects have a plot method which can be used to create visualizations - dtypes determine the type of plot produced.
Plot types can be changed via the kind argument or by using one of the DataFrame.plot.<kind> methods,
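A brief sketch of both approaches follows, assuming a small DataFrame of random values (the column names and seed are placeholders). The first panel uses the default plot type, the second selects a histogram through kind, and the third calls the corresponding accessor method.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(1234)
df = pd.DataFrame({
    "x1": rng.normal(size=100),
    "x2": rng.normal(1, 2, size=100),
})

fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3, figsize=(9, 3), layout="constrained"
)

df.plot(ax=ax1)                          # default for numeric columns: line plot
df.plot(kind="hist", alpha=0.5, ax=ax2)  # plot type chosen via the kind argument
df.plot.hist(alpha=0.5, ax=ax3)          # same plot via the accessor method
```

Passing ax= places each pandas plot on an existing matplotlib Axes, so the usual matplotlib methods can still be used to adjust titles, labels, and scales afterwards.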
The pandas library also provides the plotting submodule with several useful higher level plots,
cov = np.identity(5)
cov[1,2] = cov[2,1] = 0.5
cov[3,0] = cov[0,3] = -0.8
df = pd.DataFrame(
np.random.multivariate_normal(
mean=[0]*5, cov=cov, size=1000
).round(3),
columns = ["x1","x2","x3","x4","x5"]
)
df
x1 x2 x3 x4 x5
0 -0.676 -0.073 0.536 -0.481 0.829
1 0.868 -0.100 0.015 -1.404 0.466
2 0.028 -1.573 -2.680 -1.031 -0.655
3 0.435 -0.571 0.447 -0.424 -1.337
4 0.321 0.295 0.835 -0.262 -0.648
.. ... ... ... ... ...
995 -0.643 1.501 0.245 0.473 0.445
996 1.482 0.903 1.271 -1.003 -0.817
997 0.001 1.001 -0.196 -0.430 -0.767
998 -2.009 0.979 -0.347 1.501 0.670
999 1.703 0.235 1.582 -0.722 0.334
[1000 rows x 5 columns]
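One of the higher level plots in pandas.plotting is scatter_matrix(), which draws every pairwise scatter plot along with a histogram of each column on the diagonal. The sketch below uses a smaller, three-column version of correlated data like the above (the seed, sizes, and correlation are placeholders chosen for illustration).

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(1234)

# Three variables, with x1 and x2 correlated
cov = np.identity(3)
cov[0, 1] = cov[1, 0] = 0.5

df = pd.DataFrame(
    rng.multivariate_normal(mean=[0] * 3, cov=cov, size=200),
    columns=["x1", "x2", "x3"],
)

# Returns a 2D array of Axes, one per pair of columns
axs = pd.plotting.scatter_matrix(
    df, alpha=0.3, figsize=(5, 5), diagonal="hist"
)
```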
Sta 663 - Spring 2023