先根据自己的需求构造pytorch数据集
dataset_masked = tdst.ImageFolder(root='./data/CNcroped_face_mask/',transform=img_preprocess)
train_loader = torch.utils.data.DataLoader(dataset_masked , batch_size=BATCH_SIZE,
sampler=train_sampler)
loader里面的数据的shape看一下
# 显示几张
for x in train_loader:
# 设置画布大小
fig = plt.figure(figsize=(4,4))
print(x[0])#torch.Size([196, 3, 112, 112])一次196张图的矩阵,
print(x[1])#torch.Size([196])对应196张图的分类index
break
以上两个矩阵是放在list列表里面的
然后数据在dataloader里面后,将多个数据展示出来查看数据
# 显示几张
for x in train_loader:
# 设置画布大小
fig = plt.figure(figsize=(8,8))
for i in range(16):
# 图片位置
plt.subplot(4,4,i+1)
# 转为numpy
img = x[0][i].numpy()
# 调整 通道顺序为PIL格式
img = np.transpose(img,(1,2,0))
# 先转到[0,1],再乘以255
img =(img + 1 )/ 2 * 255
# 取整
img = img.astype('int')
# 显示
plt.imshow(img)
plt.axis('off')
plt.show()
break