从零开始Pytorch深度学习-数据封装

数据封装

深度学习整个的过程由几个模块组成:数据封装->模型定义->损失函数,优化器定义->训练函数定义

这部分主要讲数据封装:

我们以图片分类为例,图1为文件的结构(5种花)

图1

我们的绝对路径是:F:\pythonProject\blog\flower_photos

我们先设计一个函数,用来提取该路径下的所有图片,对应的标签(0-4对应5种花的类别)。

def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    #得到类别
    '''
    os.listdir(root):文件名
    os.path.isdir(os.path.join(root, cla)):判断是否是文件
    '''
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))

    #写入json文件
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    #遍历每一个类别
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))
        #验证集已经确定了,下面是通过验证集来确定训练集
        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

调用read_split_data

train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

返回的信息:训练集的路径,标签;验证集的路径,标签

 图2

我们可以看到图2,所有图片的路径放在一个列表里 。

接下来,我们就需要用dataset和dataloader来封装数据了。

我们先定义Transform(数据的变形)

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

然后,我们定义dataset

'''
    dataset中,传入路径、标签、变形
    '''
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

最后定义datasetloader

train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=0,
                                               collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=0,
                                             collate_fn=val_dataset.collate_fn)

一般设置num_workers=0

这样,我们就把数据封装好了。

备注:

 我们也可以imagefolder类:在torch.utils里有一个imagefolder类可以自动加载一个总文件夹下的各个类的图片和它的对应label,但是没有划分train和test的功能。另一个原因是在初始化的时候必须要声明transform,但是train和test的transform经常不同。

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])

注意:imagefolder类虽然方便,但是要自己设计训练,验证集的划分,并且最重要的是,两者的Transform可能不同。提供两种解决方案:

1.直接从源头解决,把上面的文件夹再分成训练,和测试,然后用不同的路径进行划分

2.在使用imagefolder的时候,先不用Transform,如下:

train_dataset = datasets.ImageFolder(root=image_path)
train_size = int(0.8 * len(train_datasets))
test_size = len(train_datasets) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(train_datasets, [train_size, test_size])

train_dataset=train_dataset.dataset#这行很重要
for images, labels in train_dataset.imgs:
    print(images)
    print(labels)

        经过random_split函数之后可以得到两个subset类,但是这个类和imagefolder初始化的对象不同。必须加上train_dataset=train_dataset.dataset才能保持和原来一样的数据格式一致。这时,可以在for循环中单独使用Transform。(个人认为这种方法效率比较低,但也能解决问题)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值