一、plt.subplots(nrows, ncols, ...)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
上述代码创建了一个有1行3列axes的figure,figure的大小为(12,6),figure的名字为'train'。如下图所示。此时plt指向最右边的ax(因为是最后创建的)。
上述代码等价于:(和上面一样,此时plt指向最右边的ax)。
import matplotlib.pyplot as plt
plt.figure("train", (12, 6))
plt.subplot(1,3,1)
plt.subplot(1,3,2)
plt.subplot(1,3,3)
二、plt当前所指的fig/ax永远是最新创建的fig/ax,在调用plt.xxx函数时,要注意操作的对象是哪一个fig的哪个ax。(但plt.show会显示所有figure)
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)
fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
axes[0].plot(x, y) # ax也有plot方法
axes[0].set_xlabel('aaa') # ax有set_xlabel方法,没有xlabel方法
plt.xlabel("epoch")
plt.title("Epoch Average Loss")
结果如下:
三、一个fig中新创建的ax可能会覆盖旧的ax
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)
fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
axes[0].plot(x, y)
axes[0].set_xlabel('aaa')
plt.subplot(1,2,2)
plt.xlabel("epoch")
plt.title("Epoch Average Loss")
结果如下: