I. Drawing compact subplots in PyTorch
(1) Drawing multiple images and their labels in a single row
1) Code
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display

np.set_printoptions(threshold=100000000)

# Download Fashion-MNIST; ToTensor() yields float tensors of shape (1, 28, 28) in [0, 1].
mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
                                               transform=transforms.ToTensor())


def use_svg_display():
    """Use svg format to display plot in jupyter"""
    # Deprecated since IPython 7.23; newer setups can call
    # matplotlib_inline.backend_inline.set_matplotlib_formats('svg') instead.
    display.set_matplotlib_formats('svg')


def get_fashion_mnist_labels(labels):
    """Map numeric class indices to their text labels."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_fashion_mnist(images, labels):
    """Plot the images side by side in one row, each titled with its label."""
    use_svg_display()
    # One row with len(images) columns; figs is a 1-D array of Axes.
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        # Drop the channel dimension: (1, 28, 28) -> (28, 28).
        f.imshow(np.squeeze(img.numpy()))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


# Take the first 10 training samples and plot them with their labels.
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
2) Result
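Why the colors look off: after np.squeeze each image is a 2-D (single-channel) array, and imshow maps such arrays through matplotlib's default colormap (viridis) rather than drawing them in grayscale.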
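Passing cmap='gray' to imshow is one way to remove the cast. Below is a minimal sketch comparing the two calls. The random 28x28 array is only a stand-in for one Fashion-MNIST image, and the variable names are illustrative rather than taken from the code above.

```python
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for one grayscale image; values in [0, 1], like ToTensor() output.
rng = np.random.default_rng(0)
img = rng.random((28, 28))

fig, (ax_default, ax_gray) = plt.subplots(1, 2, figsize=(6, 3))

# No cmap given: the 2-D array is mapped through rcParams['image.cmap'].
im_default = ax_default.imshow(img)
ax_default.set_title('default cmap')

# Explicit grayscale colormap: 0 is drawn black, 1 is drawn white.
im_gray = ax_gray.imshow(img, cmap='gray')
ax_gray.set_title("cmap='gray'")

for ax in (ax_default, ax_gray):
    ax.axis('off')

# In a notebook the figure renders automatically; in a script, call plt.show().
```

In show_fashion_mnist the same change would be f.imshow(np.squeeze(img.numpy()), cmap='gray').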
(2) Displaying multiple images as a grid
1) Code
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

# Batches of 49 images so that one batch fills a 7 x 7 grid.
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                             download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=49,
                                          shuffle=True, num_workers=0)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                            download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=49,
                                         shuffle=False, num_workers=0)
classes = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')


def imshow(img):
    """Show a (C, H, W) tensor as an image."""
    print(img.shape)
    npimg = img.numpy()
    # matplotlib expects (H, W, C), so move the channel axis to the end.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Fetch one batch. The iterator has no .next() method in recent PyTorch
# versions, so use the built-in next() instead.
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
# Tile the batch into a single image with 7 images per row.
imshow(torchvision.utils.make_grid(images, nrow=7, padding=1))
print(' '.join('%5s' % classes[labels[j]] for j in range(49)))
2) Result
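To make the layout of the grid image easier to picture, here is a rough NumPy-only sketch of the tiling idea. It is not torchvision's implementation: tile_grid is a hypothetical helper written for this illustration, it handles single-channel images only, and it assumes the same padding scheme (a border of `padding` pixels around every image).

```python
import numpy as np


def tile_grid(batch, nrow, padding=1, pad_value=0.0):
    """Lay out a (N, H, W) batch as one 2-D image with `nrow` images per row."""
    n, h, w = batch.shape
    ncol = min(nrow, n)                    # images per row
    nrows = int(np.ceil(n / ncol))         # number of rows needed
    cell_h, cell_w = h + padding, w + padding
    grid = np.full((nrows * cell_h + padding, ncol * cell_w + padding),
                   pad_value, dtype=batch.dtype)
    for k in range(n):
        r, c = divmod(k, ncol)             # grid position of image k
        top = r * cell_h + padding
        left = c * cell_w + padding
        grid[top:top + h, left:left + w] = batch[k]
    return grid


# 49 random stand-in images of size 28 x 28, tiled 7 per row.
batch = np.random.rand(49, 28, 28)
grid = tile_grid(batch, nrow=7, padding=1)
print(grid.shape)
```

With 28x28 images, 7 per row and padding=1, each side of the grid is 7 * (28 + 1) + 1 = 204 pixels.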
II. Drawing compact subplots in matplotlib
(1) Distinguishing subplot from subplots
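- plt.subplot() creates one subplot at a time; the grid size and the position of the subplot can be written as a single run of digits, e.g. plt.subplot(221).
- plt.subplots(m, n) returns a tuple of two elements: the first is the figure (canvas) fig, the second is the subplots ax, an m x n array. Access an individual subplot by indexing, e.g. XXX = ax[i, j].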
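The difference between the two calls can be sketched as follows. The grid sizes and variable names are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt

# plt.subplot: create (or select) one Axes at a time on the current figure.
fig1 = plt.figure()
ax_a = plt.subplot(221)        # 2 x 2 grid, position 1 (top-left)
ax_b = plt.subplot(2, 2, 4)    # same grid, position 4 (bottom-right)
ax_a.plot([1, 2], [3, 4])
ax_b.plot([1, 2], [4, 3])

# plt.subplots: create the figure and every Axes in one call.
fig2, axes = plt.subplots(2, 3)
axes[1, 2].plot([1, 2], [3, 4])    # row 1, column 2 (zero-based)

# With a single row or column the returned array is 1-D,
# so one index is enough.
fig3, row_axes = plt.subplots(1, 3)
row_axes[0].plot([1, 2], [3, 4])

# In a notebook the figures render automatically; in a script, call plt.show().
```

Note that plt.subplot positions count from 1, while indexing into the array returned by plt.subplots starts at 0.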
(2) Code
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 2)
ax[0, 0].plot([2, 1], [3, 4])
ax[0, 1].plot([1, 2], [3, 4])
ax[1, 0].plot([1, 2], [4, 3])
ax[1, 1].plot([1, 2], [3, 4])
plt.show()
(3) Result
III. Manually stitching the images in a folder together
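import PIL.Image as Image
import os

IMAGES_PATH = './loop_img/'  # folder containing the source images
IMAGES_FORMAT = ['.jpg', '.JPG', '.png']  # accepted file extensions
IMAGE_SIZE = 256  # each image is resized to IMAGE_SIZE x IMAGE_SIZE
IMAGE_ROW = 2  # number of rows in the stitched image
IMAGE_COLUMN = 2  # number of columns in the stitched image
IMAGE_SAVE_PATH = 'final.jpg'  # output file

# Collect the image file names; sorted() makes the order deterministic,
# since os.listdir returns names in arbitrary order.
image_names = sorted(name for name in os.listdir(IMAGES_PATH) for item in IMAGES_FORMAT if
                     os.path.splitext(name)[1] == item)

# The number of images must fill the grid exactly.
if len(image_names) != IMAGE_ROW * IMAGE_COLUMN:
    raise ValueError("The number of images does not match IMAGE_ROW * IMAGE_COLUMN!")


def image_compose():
    """Paste the images row by row onto one canvas and save it."""
    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE, IMAGE_ROW * IMAGE_SIZE))
    for y in range(1, IMAGE_ROW + 1):
        for x in range(1, IMAGE_COLUMN + 1):
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            from_image = Image.open(IMAGES_PATH + image_names[IMAGE_COLUMN * (y - 1) + x - 1]).resize(
                (IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
            # The top-left corner of cell (x, y) on the canvas.
            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))
    return to_image.save(IMAGE_SAVE_PATH)


image_compose()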
- Result
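The paste offsets can be checked without any files on disk. The following is a small sketch under assumed values: solid-color tiles created in memory stand in for the images in the folder, and TILE, ROWS, COLS and the colors are hypothetical choices for this illustration.

```python
from PIL import Image

TILE = 64            # hypothetical tile size in pixels
ROWS, COLS = 2, 2    # grid layout
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]

# Solid-color tiles stand in for the resized source images.
tiles = [Image.new('RGB', (TILE, TILE), color) for color in colors]

canvas = Image.new('RGB', (COLS * TILE, ROWS * TILE))
for index, tile in enumerate(tiles):
    # Zero-based row and column of this tile in the grid.
    row, col = divmod(index, COLS)
    # paste() takes the (left, upper) corner in pixels.
    canvas.paste(tile, (col * TILE, row * TILE))

print(canvas.size)
```

Using divmod on a zero-based index gives the same placement as the one-based loops above, filling the grid row by row.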
Reference:
Stitching multiple images into one large image with Python (original title: 使用python将多张图片拼接成大图)