[pytorch]构建并加载自己的数据集

[pytorch]构建并加载自己的数据集)

pytorch为我们封装好了很多经典的数据集在torchvision.datasets包里, torchvision.datasets这个包中包含MNIST、FakeData、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,并且提供了数据集设置的一些重要参数设置,可以通过简单数据集设置来进行数据集的调用。从这些数据集中我们也可以看出数据集设置的主要变量有哪些并且有什么功能对将来自己数据集的设置也有极大的帮助。

在这里简单的举个例子:

torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False#  参数介绍:
 #root(string) - 数据集的根目录在哪里MNIST/processed/training.pt 和 MNIST/processed/test.pt存在。
 # train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
 #download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
 # transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
 # target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换。

torchvision.datasets的具体使用方法详见pytorch官方文档

但是在实践中很多时候我们还是需要设计和加载自己的数据集的,虽然我们可能有现成的数据,比如说图片和它们的标签,但是我们需要将其设计成类方便使用。

必要的函数介绍

为了用pytorch实现自己的数据集,一般是要将自己的数据集设计成一个类,这个类必须包含三个函数:

  1. init(): 这个含糊的参数一般是包含数据所在的文件夹,还有就是对数据进行的transform。
  2. len(): 这个函数不需要参数,一般返回的是数据集的大小。
  3. getitem(): 这个函数的参数一般是索引,返回的是数据集的某一个样本,返回的一般是tensor。

数据集实现示例

以下的示例是逼着自己在实现视频小样本分类的时候为Kinetics数据集实现的一个可以方便被pytorch调用的数据集类,值得一提的是在自己实现类的时候需要继承torch.utils.data.Dataset。

class VideoDataset(Dataset):
    def __init__(self, info_txt, root_dir, mode='train',data_aug=None,transform=None):
        # set params
        self.info_txt=info_txt
        self.root_dir=root_dir
        self.mode=mode
        self.data_aug=data_aug
        self.transform = transform

        # read info_list
        self.info_list=open(self.info_txt).readlines()

    def __len__(self):
        return len(self.info_list)

    def __getitem__(self, idx):
        info_line=self.info_list[idx]
        video_info=info_line.strip('\n')
        video, video_frame_path, video_shape=get_video_from_video_info_2(video_info,mode=self.mode)
        video_label=get_label_from_video_info(video_info,self.info_txt)

        sample = {'video': video, 'label': [int(video_label)],'video_frame_path':video_frame_path,'video_shape':video_shape}

        sample['video'] = torch.FloatTensor(sample['video'])
        sample['label'] = torch.FloatTensor(sample['label'])

        return sample

数据集的加载

数据集的加载分为三部分:数据集的初始化,数据集的load和数据集的使用:

  1. 数据集初始化:
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
  1. 数据集的load,这个只需要调用torch.utils.data.dataloader这个函数:
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
  1. 最后方便理解,笔者贴上了使用的方法:
for batch_index, data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
已标记关键词 清除标记
【课程介绍】       Pytorch项目实战 垃圾分类 课程从实战的角度出发,基于真实数据集与实际业务需求,结合当下最新话题-垃圾分类问题为实际业务出发点,介绍最前沿的深度学习解决方案。     从0到1讲解如何场景业务分析、进行数据处理,模型训练与调优,最后进行测试与结果展示分析。全程实战操作,以最接地气的方式详解每一步流程与解决方案。     课程结合当下深度学习热门领域,尤其是基于facebook 开源分类神器ResNext101网络架构,对网络架构进行调整,以计算机视觉为核心讲解各大网络的应用于实战方法,适合快速入门与进阶提升。 【课程要求】 (1)开发环境:python版本:Python3.7+; torch 版本:1.2.0+; torchvision版本:0.4.0+ (2)开发工具:Pycharm; (3)学员基础:需要一定的Python基础,及深度学习基础; (4)学员收货:掌握最新科技图像分类关键技术; (5)学员资料:内含完整程序源码和数据集; (6)课程亮点:专题技术,完整案例,全程实战操作,徒手撸代码 【课程特色】 阵容强大 讲师一直从事与一线项目开发,高级算法专家,一直从事于图像、NLP、个性化推荐系统热门技术领域。 仅跟前沿 基于当前热门讨论话题:垃圾分类,课程采用学术届和工业届最新前沿技术知识要点。 实战为先 根据实际深度学习工业场景-垃圾分类,从产品需求、产品设计和方案设计、产品技术功能实现、模型上线部署。精心设计工业实战项目 保障效果 项目实战方向包含了学术届和工业届最前沿技术要点 项目包装简历优化 课程内垃圾分类图像实战项目完成后可以直接优化到简历中 【课程思维导图】 【课程实战案例】
©️2020 CSDN 皮肤主题: 数字20 设计师:CSDN官方博客 返回首页