简述
这个操作很常用。用于做对比区分
代码范式:
ax.set_title(str)
但是结合相对应的子图的设计却有多种的操作方式。比如 实例2,实例3
其中实例2的例子当中,不能将nrows改成10(虽然不知道为什么)。
- 实例1和实例2比较适用于特殊的情况(如果行或者列超过10第一方法就不太好了,但是第二种方法就可行)
- 实例3会更加灵活。(推荐使用)
实例1
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(6, 6.5))
for i in range(4):
ax = plt.subplot(221+i)
alpha = 0.98 / 4 * i + 0.01
ax.set_title('%.3f' % alpha)
t1 = np.arange(0.0, 1.0, 0.01)
for n in [1, 2, 3, 4]:
plt.plot(t1, t1 ** n, label="n=%d" % n)
leg = plt.legend(loc='best', ncol=4, mode="expand", shadow=True)
leg.get_frame().set_alpha(alpha)
plt.savefig('1.png')
plt.show()
实例2
CIFAR10为例:
导入包
import torch
import torchvision
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import pickle
下载并加载数据
DOWNLOAD_CIFAR10 = False
cifar10_root = './cifar10/'
if not (os.path.exists(cifar10_root)) or not os.listdir(cifar10_root):
# not mnist dir or mnist is empyt dir
DOWNLOAD_CIFAR10 = True
train_data = torchvision.datasets.CIFAR10(
root=cifar10_root,
train=True, # this is training data
transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_CIFAR10,
)
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)
with open('./cifar10/cifar-10-batches-py/batches.meta', 'rb') as f:
data = pickle.load(f)
调用数据
for step, (x, y) in enumerate(train_loader):
print(x.shape, y.shape)
print(y)
fig = plt.figure(figsize=(10, 10))
fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 1.5))
for i in range(10):
ax = axs[i]
ax.axis("off")
ax.set_title(data['label_names'][y[i]])
ax.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))
plt.savefig('cifar-10.png')
plt.show()
break
实例3
前面的步骤和实例2一模一样,就不重复了。
调用数据
注意plt.axis()
不可以拿出到循环外面来。
for step, (x, y) in enumerate(train_loader):
print(x.shape, y.shape)
print(y)
fig = plt.figure(figsize=(20, 20))
for i in range(100):
ax = plt.subplot(10, 10, i+1)
plt.axis("off")
ax.set_title(data['label_names'][y[i]])
plt.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))
plt.savefig('cifar-100.png')
plt.show()
break
效果类似于:
for step, (x, y) in enumerate(train_loader):
print(x.shape, y.shape)
print(y)
fig = plt.figure(figsize=(20, 20))
for i in range(100):
ax = plt.subplot(10, 10, i+1)
ax.axis("off")
ax.set_title(data['label_names'][y[i]])
plt.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))
plt.savefig('cifar-100.png')
plt.show()
break