前言总结
老师要求使用真实的、属于我们自己的数据集,于是选用了部分GID数据集。本周先简单写了数据集部分的代码,等后面写完再贴出来吧~
刚开始接触pytorch推荐学习土堆老师的视频
PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili
复现项目可以参考bubbliiing大佬
一、数据集的构建
pytorch关于构建数据集有两个重要的类:Dataset和Dataloader
本视频所使用的环境为PyTorch 1.1
1.Dataset:读取数据(图片+标签+索引)【p6,p14】
(1)基础构成
它有三个重要的方法:`__init__`、`__getitem__`、`__len__`
`__init__`—初始化数据根路径和文件读取路径
`__getitem__`—根据索引idx取出一个样本,返回 img, target
`__len__`—返回数据集的长度(样本数)
在这个视频里,我除了学到了这些不算崭新的知识点,还学到了pycharm控制台的用法,感谢土堆
python console可以更方便地查看变量的类型和具体内容
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
# Module-level TensorBoard writer; its event files are written under ./logs.
writer = SummaryWriter(log_dir="logs")
class MyData(Dataset):
    """Dataset pairing images with text labels stored in two sibling folders.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``. Both directory listings are sorted the same way,
    so sample ``i`` pairs the i-th image with the i-th label file. This is
    only correct when every label file has the same name (stem) as its image.

    Args:
        root_dir: Dataset root folder, e.g. ``"dataset/train"``.
        image_dir: Image sub-folder under ``root_dir``.
        label_dir: Label sub-folder under ``root_dir``.
        transform: Callable applied to each opened PIL image, or ``None`` to
            return the PIL image unchanged.

    Raises:
        ValueError: If the two folders contain a different number of files.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # Identical sorting keeps image i and label i aligned (see docstring).
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform
        # Checked with an exception rather than ``assert`` so the check still
        # runs under ``python -O``.
        if len(self.image_list) != len(self.label_list):
            raise ValueError(
                f"{self.image_path} has {len(self.image_list)} files but "
                f"{self.label_path} has {len(self.label_list)}"
            )

    def __getitem__(self, idx):
        """Return ``{'img': image, 'label': label}`` for sample ``idx``."""
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])
        img = Image.open(img_item_path)
        # The label is the first line of the file; readline keeps any trailing newline.
        with open(label_item_path, 'r', encoding='utf-8') as f:
            label = f.readline()
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'label': label}

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.image_list)
if __name__ == '__main__':
    # Resize to a fixed (256, 256) so every sample has the same shape.
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    root_dir = "dataset/train"
    ants_dataset = MyData(root_dir, "ants_image", "ants_label", transform)
    bees_dataset = MyData(root_dir, "bees_image", "bees_label", transform)
    # Dataset.__add__ concatenates: ant samples first, then bee samples.
    train_dataset = ants_dataset + bees_dataset
    # Built to show batching of the combined dataset; not iterated below.
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)
    # Log sample #119 of the combined dataset to TensorBoard for inspection.
    sample = train_dataset[119]
    writer.add_image('error', sample['img'])
    writer.close()
(2)torchvision中的标准数据集调用
pytorch.org中有一个torchvision模块,集成了许多计算机视觉领域的常用数据集
ctrl+p快捷键显示类所需要的参数
import torchvision
from torch.utils.tensorboard import SummaryWriter

# ToTensor converts each PIL image into a tensor the writer can log.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
# download=True fetches CIFAR-10 into ./dataset if it is not already there.
# The training split is loaded too but is not used below.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# Log the first 10 test images under one tag, one image per step. The
# context manager closes the writer even if an exception interrupts the loop.
with SummaryWriter("p10") as writer:
    for i in range(10):
        img, _target = test_set[i]
        writer.add_image("test_set", img, i)
2.DataLoader:以批次方式整合数据【p15】
DataLoader按索引从Dataset中取出样本,并根据batch_size、shuffle、drop_last等参数把它们打包成一个个批次(batch),供训练时迭代使用
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the CIFAR-10 test split.
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# 64 images per batch, reshuffled every epoch; drop_last discards the final
# incomplete batch so every logged batch has exactly 64 images.
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# First image of the test split and its target (class index).
img, target = test_data[0]
print(img.shape)
print(target)

# The context manager closes the writer even if the loop raises.
with SummaryWriter("dataloader") as writer:
    for epoch in range(2):
        # One tag per epoch; the step is the batch index within that epoch.
        for step, (imgs, _targets) in enumerate(test_loader):
            writer.add_images(f"Epoch: {epoch}", imgs, step)
3.Transform:![](https://i-blog.csdnimg.cn/direct/226ed56d60d542aaa384677684c405ad.png)
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
class MyData(Dataset):
    """Dataset yielding ``(image, label)`` pairs from two sibling folders.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``. Both directory listings are sorted the same way,
    so sample ``i`` pairs the i-th image with the i-th label file. This is
    only correct when every label file has the same name (stem) as its image.

    Args:
        root_dir: Dataset root folder, e.g. ``"dataset/train"``.
        image_dir: Image sub-folder under ``root_dir``.
        label_dir: Label sub-folder under ``root_dir``.
        transform: Callable applied to each opened PIL image, or ``None`` to
            return the PIL image unchanged.

    Raises:
        ValueError: If the two folders contain a different number of files.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # Identical sorting keeps image i and label i aligned (see docstring).
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform
        # Checked with an exception rather than ``assert`` so the check still
        # runs under ``python -O``.
        if len(self.image_list) != len(self.label_list):
            raise ValueError(
                f"{self.image_path} has {len(self.image_list)} files but "
                f"{self.label_path} has {len(self.label_list)}"
            )

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])
        img = Image.open(img_item_path)
        # The label is the first line of the file; readline keeps any trailing newline.
        with open(label_item_path, 'r', encoding='utf-8') as f:
            label = f.readline()
        # Use this instance's transform, not a same-named global.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.image_list)
# Resize(400) with a single int rescales the shorter edge to 400 px (aspect
# ratio kept), so images may still differ in size; ToTensor then converts them.
transform = transforms.Compose([
    transforms.Resize(400),
    transforms.ToTensor(),
])
root_dir = "dataset/train"
image_ants, label_ants = "ants_image", "ants_label"
image_bees, label_bees = "bees_image", "bees_label"
ants_dataset = MyData(root_dir, image_ants, label_ants, transform=transform)
bees_dataset = MyData(root_dir, image_bees, label_bees, transform=transform)
4.Tensorboard:
(1)作用:绘制Loss损失函数曲线,显示图片
(2)ctrl+鼠标点击,可以查看pycharm类
(3)使用分两步:一是在代码中用SummaryWriter写入日志;二是在终端中启动tensorboard。启动时必须用--logdir指定日志文件夹,也可以用--port指定端口,例如:tensorboard --logdir=logs --port=6007
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

# NOTE(review): other snippets use "dataset/train"; confirm this "data/train" path.
image_path = "data/train/ants_image/6240329_72c01e663e.jpg"

# Context managers close the writer and the image file even on errors.
with SummaryWriter("logs") as writer:
    with Image.open(image_path) as img_PIL:
        img_array = np.array(img_PIL)
    print(type(img_array))
    print(img_array.shape)
    # The numpy array is height x width x channels, so tell add_image it is HWC.
    writer.add_image("train", img_array, 1, dataformats='HWC')
    # Plot y = 2x; the logged value must match the "y=2x" tag.
    for i in range(100):
        writer.add_scalar("y=2x", 2 * i, i)
5.两大法宝
查看和使用package
dir():查看工具里有什么
help():询问工具箱怎么使用
e.g. dir(torchvision.datasets) help(torchvision.datasets)
6.实践问题
(1)查看png标签图片真值
因为VOC数据集的标签图一般直接以类别编号作为像素值存储,所以可以直接读取像素值来查看标签真值
首先转换为numpy,然后再查看
注意:PIL中的Image对象本身不是numpy数组,但它实现了numpy的数组接口(__array_interface__),因此可以直接用np.asarray或np.array转换
import numpy as np
from PIL import Image

# VOC segmentation label image: pixel values are class indices.
img_path = "dataset/VOCdevkit/VOC2012/SegmentationClass/2007_000032.png"
# Read the pixels into a numpy array while the file is open; the context
# manager then closes the file handle.
with Image.open(img_path) as img:
    label_array = np.asarray(img, dtype=np.uint8)
print(label_array.shape)
# Distinct label values present in this mask.
print(np.unique(label_array))
# A 10x10 patch of raw label values.
print(label_array[200:210, 200:210])
(2)transform在语义分割领域的常见使用方法(来自lbq学长和CSDN)
PyTorch(torchvision)提供数据增强工具箱transforms,其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、缩放、裁剪、归一化等。
语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行某些数据增强时,需要对标注图像(mask)进行同步操作,如旋转、裁剪、翻转等。对mask做缩放或旋转时应使用最近邻插值,否则插值会产生原本不存在的类别值。
当数据集足够大时,通过数据增强来扩充数据集也许并不是必要的,毕竟一个真实的数据集好过一千个虚拟的数据集