Layout and multi-subplots using Python matplotlib¶

Qianjiang Hu¶

2021-11-23¶
update: 2022-11-06¶
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
import matplotlib.gridspec as gridspec
In [2]:
x1 = np.linspace(0.0, 10.0)
x2 = np.linspace(0.0, 2.5)
y1 = np.cos(2 * np.pi * x1)
y2 = np.cos(2 * np.pi * x2)

Before Strating¶

Here introdude axes.flat, enumerate() and itertools.product() before starting the layout tutorial

In [3]:
### axes is an np.array object

fig, axes = plt.subplots(ncols=2,nrows=3, sharex=True, sharey=True)
print(type(axes))
print(axes.shape)
print(axes)
<class 'numpy.ndarray'>
(3, 2)
[[<AxesSubplot: > <AxesSubplot: >]
 [<AxesSubplot: > <AxesSubplot: >]
 [<AxesSubplot: > <AxesSubplot: >]]
In [4]:
### `aexs.flat`
# `aexs.flat` or `aexs.flatten()`: return a 1-D iterator over the aexs array
a = np.array([[2,3],
              [4,5],
              [6,7]])

print(a) 
print('----------')  

print(a.flat)
print('----------')  

for i in a.flat:
    print(i, end = ',')
[[2 3]
 [4 5]
 [6 7]]
----------
<numpy.flatiter object at 0x120814c00>
----------
2,3,4,5,6,7,
In [5]:
## here is the example what axes.flat and axes.flatten() looks like:
fig, axes = plt.subplots(3,2)
axes.flatten()
Out[5]:
array([<AxesSubplot: >, <AxesSubplot: >, <AxesSubplot: >, <AxesSubplot: >,
       <AxesSubplot: >, <AxesSubplot: >], dtype=object)
In [6]:
### `enumerate()`
# enumerate()` # The enumerate() method adds a counter to an iterable and returns it (the enumerate object)
clist = ['b','g','r','c','m','y']
list(enumerate(clist))
Out[6]:
[(0, 'b'), (1, 'g'), (2, 'r'), (3, 'c'), (4, 'm'), (5, 'y')]
In [7]:
### itertools.product()
# The product() method of the itertools module returns the cartesian product of the input iterables.
# import itertools

nrow = 3; ncol = 2
for index, [i,j] in enumerate(itertools.product(range(nrow), range(ncol))):
    print(index)
    print('------')
    
    print(i)
    print('------')
    
    print([i,j])
0
------
0
------
[0, 0]
1
------
0
------
[0, 1]
2
------
1
------
[1, 0]
3
------
1
------
[1, 1]
4
------
2
------
[2, 0]
5
------
2
------
[2, 1]
In [8]:
## plt.tight_layout()  ## Adjust the padding between and around subplots.

plt.subplot(nrows, ncols, idx, sharex = False, sharey = False)¶

  • plt.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, **fig_kw)
In [9]:
# subplots()函数的返回值是一个画布对象fig和一个或多个坐标轴实例对象axis(简写为ax)。
# 当坐标轴实例对象为一个时,其返回值为元组(fig,ax);当坐标轴实例对象为多个时,其返回值为(fig,[ax1,ax2,…]).
In [10]:
t = np.arange(0, 10, 0.01)
ax1 = plt.subplot(211)                 # two row, 1 column, and indicate the plot1
ax1.plot(t, np.sin(2*np.pi*t), 'r--')

ax2 = plt.subplot(212, sharex=ax1)     # two row, 1 column, and indicate the plot2
ax2.plot(t/2, np.sin(4*np.pi*t),'g-')

plt.show()
In [11]:
def f(t):
    return np.exp(-t) * np.cos(2*np.pi*t)

t1 = np.arange(0.0, 3.0, 0.01)

ax1 = plt.subplot(212)      # 2 sub-locations (2r X 1c, set ax1 as the the 2nd, the one at the bottom)
ax1.margins(0.05)           # Default margin is 0.05, value 0 means fit
ax1.plot(t1, f(t1), 'r--')

ax2 = plt.subplot(221)      # 4 sub-locations (2r X 2c, set ax2 as the the 1st, the one at the upperleft)
ax2.margins(2, 2)           # Values >0.0 zoom out
ax2.plot(t1, f(t1), 'g-')
ax2.set_title('Zoomed out')

ax3 = plt.subplot(222)      # 4 sub-locations (2r X 2c, set ax3 as the the 2nd, the one at the upperright)
ax3.margins(x=0, y=-0.25)   # Values in (-0.5, 0.0) zooms in to center
ax3.plot(t1, f(t1), 'b--')
ax3.set_title('Zoomed in')

plt.show()
In [12]:
# plt.subplot(nrows, ncols, index) # plot one axes each time, mixed layout is available

# subplot1 with blue
plt.subplot(2,3,1)
plt.plot(x1,y1, c = 'b')
plt.ylabel('v' + '_'+'b')

# subplot2 with green
plt.subplot(2,3,2)
plt.plot(x1,y1, c = 'g')
plt.ylabel('v' + '_'+'g')

# subplot3 with red
plt.subplot(2,3,3)
plt.plot(x1,y1, c = 'r')
plt.ylabel('v' + '_'+'r')

# subplot4 with cyan
plt.subplot(2,3,4)
plt.plot(x1,y1, c = 'c')
plt.ylabel('v' + '_'+'c')

# subplot5 with magenta
plt.subplot(2,3,5)
plt.plot(x1,y1, c = 'm')
plt.ylabel('v' + '_'+'m')

# subplot6 with yellow
plt.subplot(2,3,6)
plt.plot(x1,y1, c = 'y')
plt.ylabel('v' + '_'+'y')

plt.tight_layout()  ### Adjust the padding between and around subplots to avoid the overlap.
In [13]:
# for loop to creating multiple subplots

t = np.arange(0, 10, 0.01)
v = np.sin(2*np.pi*t)

clist = ['b','g','r','c','m','y']
for i in range(len(clist)):
    plt.subplot(2,3,i+1)
    plt.plot(x1,y1, c = clist[i])
    plt.ylabel('v')
    plt.title(str(i) + '_' + clist[i])

plt.tight_layout()  ### Adjust the padding between and around subplots to avoid the overlap.
In [14]:
# for loop and enumerate to creating multiple subplots

clist = ['b','g','r','c','m','y']
for i, col in enumerate(clist):
    plt.subplot(2,3,i+1)
    plt.plot(x1,y1, c = col)
    plt.ylabel('v')
    plt.title(str(i) + '_' + col)

plt.tight_layout()  ### Adjust the padding between and around subplots to avoid the overlap.
In [15]:
# mixed layout is available

# subplot1 with blue, which will located in left half part
plt.subplot(1,2,1)
plt.plot(x1,y1, c = 'b')
plt.ylabel('v' + '_'+'b')

# subplot2 with green, which will located in right upper panel
plt.subplot(2,2,2)
plt.plot(x1,y1, c = 'g')
plt.ylabel('v' + '_'+'g')

# subplot3 with red, , which will located in right lower panel
plt.subplot(2,2,4)
plt.plot(x1,y1, c = 'r')
plt.ylabel('v' + '_'+'r')


plt.tight_layout()  ### Adjust the padding between and around subplots to avoid the overlap.

plt.subplots(nrows, ncols, sharex = False, sharey = False)¶

In [16]:
fig = plt.figure(figsize=(6,6))  ###
ax = fig.subplots(3,3, sharex=True, sharey=True)
ax[0,1].plot([1,2],[1,3])
ax[1,0].plot([1,2],[1,3])
ax[1,1].plot([1,2],[1,3], '--r')
ax[2,2].plot([1,2],[1,3])

plt.tight_layout() 
In [17]:
# for loop, enumerate and axes.flat to creat multi-subplots
clist = ['b','g','r','c','m','y']

fig, axes = plt.subplots(3,2)
for i, ax in enumerate(axes.flat):
    ax.plot(x1,y1,c = clist[i])
    ax.set_title(str(i) + '_' + clist[i])

plt.tight_layout() 
In [18]:
# for loop, enumerate and axes.flat to creat multi-subplots
clist = ['b','g','r','c','m','y']

import itertools
nrow = 3; ncol = 2
fig, axes = plt.subplots(nrow, ncol, sharex = True, sharey = True)

for a, [i,j] in enumerate(itertools.product(range(nrow), range(ncol))):
    axes[i,j].plot(x1,y1, c = clist[a])
    axes[i,j].set_title(str(i) + '_' + clist[a])

plt.tight_layout() 
In [19]:
# for loop, enumerate and axes.flat to creat multi-subplots
clist = ['b','g','r','c','m','y']

nrow = 3, ncol = 2
fig, axes = plt.subplots(nrow, ncol, sharex = True, sharey = True)

for i, j in zip(nrow,ncol):
    ax[i,j].plot(x1,y1, c = clist[i])
    ax.set_title(str(i) + '_' + clist[i])

plt.tight_layout() 
  Cell In [19], line 4
    nrow = 3, ncol = 2
           ^
SyntaxError: cannot assign to literal

fig.add_axes(), seting the position of axes¶

  • the figure is of Rectangular Coordinates downleft(0,0),upright(1,1)
  • fig.add_axes(rect, projection=None, polar=False, **kwargs)
  • rect: The dimensions [left, bottom, width, height] of the new Axes
  • All quantities are in fractions of figure width and height
In [ ]:
fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.5], frameon = False)  #[left, bottom, width, height]
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.2])

ax1.plot(x1, y1, linestyle = 'dashed', color = 'r')
ax2.plot(x2, y2,linestyle = '--', color = 'g')
plt.show()
In [20]:
# subplot1 and subplot2 share the same x axis
# subplot3 share the same y axis line with subplot1 and subplot2
fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.3, 0.8, 0.5])  # bottom = ax2 bottom + ax2 height
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.2])  # [left, bottom, width, height]
ax3 = fig.add_axes([0.9, 0.1, 0.8, 0.7])  # left = ax1/ax2 left + ax1/ax2 width
ax1.plot(x1, y1, linestyle = 'dashed', color = 'r')
ax2.plot(x2, y2,linestyle = '--', color = 'g')
ax3.plot(x1, y2,linestyle = '-', color = 'purple')
plt.show()

plt.axes([left,bottom,width,height],frameon=True, facecolor='r')¶

plt.axes() 不受 fig.subplots() 影响

In [21]:
fig = plt.figure(figsize=(5,6),facecolor="#9B9B9B") # set backgroud as gray
fig.subplots(2,2) #为了更好地看出axes的效果,这里设置了子区
plt.axes( [0.4,0.4,0.3,0.3], frameon=True, facecolor="r" )
Out[21]:
<Axes: >

The difference between add_axes() and axes()¶

In [22]:
fig = plt.figure(figsize=(4,4),facecolor="#9B9B9B") #背景灰色
fig.subplots(1,2) ##为了更好地看效果,这里设置了子区
fig.add_axes([0.3,0.3, 0.3,0.3], facecolor="red" ) 
plt.axes([0,0,0.3,0.3],facecolor="black")
# 无论是用add_axes还是axes()来添加新坐标轴,它们的执行结果都是一样的。只是add_axes()需要用画布实例fig去调用,不能直接用plt调用。
plt.show()

setup layout by plt.add_subplot()¶

In [23]:
fig = plt.figure(figsize=(4,4),facecolor="#9B9B9B")
plt.subplot(2,1,1)
fig.add_subplot(1,3,3,facecolor="r")  # `fig.add_subplot` is stackable
Out[23]:
<AxesSubplot: >
In [24]:
fig = plt.figure(figsize=(4,4),facecolor="#9B9B9B")
plt.subplot(2,1,1)
plt.subplot(1,3,3,facecolor="r")   # `plt.subplot` is unstackable
plt.show()
/var/folders/7_/v68z3s1d1jd08fr0gkgys5kc0000gn/T/ipykernel_73160/2621669642.py:3: MatplotlibDeprecationWarning: Auto-removal of overlapping axes is deprecated since 3.6 and will be removed two minor releases later; explicitly call ax.remove() as needed.
  plt.subplot(1,3,3,facecolor="r")   # `plt.subplot` is unstackable

The difference between plt.subplots() and fig.add_subplot()¶

In [25]:
# add_subplot()同样需要用画布实例fig去调用,不能直接用plt调用,
# 但是add_subplot()与subplot()相比,在显示结果上明显多了add的功能,
# 因为同时调用两个plt.subplot()是不能增加新子区的(只是后面的plt.subplot()覆盖了前面的,
# 但是通过add_subplot()就可以添加新子区。

plt.subplot2grid(shape, loc, rowspan=1, colspan=1, fig=None, **kwargs)¶

In [26]:
# rowspan : int, Number of rows for the axis to span to the right.
# colspan : int, Number of columns for the axis to span downwards.
In [27]:
fig = plt.figure()
ax1 = plt.subplot2grid((2,2),(0,0),rowspan=2)
ax2 = plt.subplot2grid((2,2),(1,1))
ax1.plot(x1,y2)
ax2.plot(y1,y2)
plt.show()

gs.GridSpec()¶

  • import matplotlib.gridspec as gs
  • gs = gs.GridSpec(rows,cols)
  • ax1 = plt.subplot(gs[locx,locy])
In [28]:
# import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(4,4)

ax1 = plt.subplot(gs[1:3,1])
ax2 = plt.subplot(gs[0,3])
ax3 = plt.subplot(gs[0,:3])
ax4 = plt.subplot(gs[3,0:])
In [29]:
# prepare multioanel plot
fig = plt.figure(1, figsize = (5,5))
# remember, grifdpec is rows, then columns
gs = gridspec.GridSpec(15,15)
gs.update(wspace = 1, hspace = 1)  # adjust the space amomg subplots
# Generate main figure
sub1 = fig.add_subplot(gs[2:15, 0:14])
sub2 = fig.add_subplot(gs[1:2, 0:14])
sub3 = fig.add_subplot(gs[0:1, 0:14])
sub4 = fig.add_subplot(gs[2:15, 14:15])
fig.add_axes([0.3,0.3, 0.3,0.3], facecolor="red" ) 
plt.axes([0,0,0.3,0.3],facecolor="black")
Out[29]:
<Axes: >

Axes.twinx() and Axes.twiny()¶

Axes.twinx() Create a twin Axes sharing the xaxis. Axes.twiny() Create a twin Axes sharing the yaxis.

In [30]:
fig=plt.figure(1)
ax1=plt.subplot(111)
ax2=ax1.twinx()
ax1.plot(np.arange(1,5),'g--')
ax1.set_ylabel('ax1',color='r')
ax2.plot(np.arange(7,10),'b-')
ax2.set_ylabel('ax2',color='b')
plt.show()

Adjust the space among subplots¶

plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)

In [31]:
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(4,4)

ax1 = plt.subplot(gs[1:3,1])
ax2 = plt.subplot(gs[0,3])
ax3 = plt.subplot(gs[0,:3])
ax4 = plt.subplot(gs[3,0:])
In [32]:
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(4,4)

ax1 = plt.subplot(gs[1:3,1])
ax2 = plt.subplot(gs[0,3])
ax3 = plt.subplot(gs[0,:3])
ax4 = plt.subplot(gs[3,0:])
plt.subplots_adjust(wspace=0.001, hspace=0.5) # adjust the space wspace--> horizontal space, hspace--> vertical space

Reference¶

In [33]:
# https://www.cxymm.net/article/weixin_44845160/110160528
# https://www.cxyzjd.com/article/imxlw00/111885442
# https://www.php.cn/python-tutorials-358840.html
# https://zhuanlan.zhihu.com/p/61752314
# https://zhuanlan.zhihu.com/p/404145594
In [34]:
! # https://matplotlib.org/stable/tutorials/intermediate/gridspec.html
! # https://matplotlib.org/stable/tutorials/intermediate/constrainedlayout_guide.html
! # https://matplotlib.org/stable/tutorials/intermediate/tight_layout_guide.html
In [ ]: