2023 11.11-11.17周报

一、上周工作

理解FCNVMB代码

二、本周计划

pytorch重写Dataset,加载到Dataloader——OpenWFI数据集。仿照FCNVMB来重写dataset部分。

三、完成情况

在pytorch的一些案例学习中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

如果不想用pytorch自带的数据集,想加载自己的数据集怎么办?

那么就要通过重写一个继承了Datasets的MyDataset类来放置自己的数据集(如本次使用的OpenWFI数据集)。简单来说由于Dataloader只认识Dataset形式的数据集,所以如果我们要用自己的数据集,则也要把我们的数据变成那样。

Dataset:

        Dataset本质上就是一个抽象类,可以把数据封装成Python可以识别的数据结构。Dataset类不能实例化,所以在使用Dataset的时候,我们需要定义自己的数据集类,也是Dataset的子类,来继承Dataset类的属性和方法。Dataset可作为DataLoader的参数传入DataLoader,实现基于张量的数据预处理。Dataset主要有两种类型,分别为Map-style datasets和Iterable-style datasets。

        构造一个MyDataset数据类,需要继承Dataset,并重写Dataset中的方法。我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。 __getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

        总而言之,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。即在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。

class MyDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
        self.length = data.shape[0]
        
    def __getitem__(self, idx):
        label = self.label[idx]
        data = self.data[idx]
        return label, data

    def __len__(self):
        return self.length

DataLoader:

Dataset和DataLoader是一起使用的,在模型训练的过程中不断为模型提供数据,同时使用Dataset加载出来的数据集也是DataLoader的第一个参数。所以,DataLoader本质上就是用来将已经加载好的数据以模型能够接收的方式输入到即将训练的模型中去。

通过输入一个数据集,以及常用参数如:batch_size、shuffle,就可以得到一个打包好的迭代器。这个迭代器包含了batch_size的序号及根据batch_size分割好的数据块。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \
    batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \
    drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

参数:

  • dataset:定义的dataset类返回的结果。
  • batchsize:每个bacth要加载的样本数,默认为1。
  • shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False。
  • sample:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;batch_sample类似,表示一次返回一个batch的index。
  • num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程。
  • collate_fn:表示合并样本列表以形成小批量的Tensor对象。
  • pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。
  • drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False。

:通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。

四、存在的主要问题(已解决)

重写时地震数据部分的维度处理

对重写Dataset中的getitem方法中的参数idx的理解

——使用场景:在定义类时,如果希望能按照键取类的值,则需要定义__getitem__方法。

——目的:如果给类定义了__getitem__方法,则当按照键取值时,可以直接返回__getitem__方法执行的结果。

  • 19
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值