手把手教你写FWI代码3:关于数据处理

目录

自定义数据集名称FWIDataset, 步骤如下:

1、继承 Dataset 类

2、__init__ 方法初始化数据集 

 3、__getitem__ 方法根据给定的索引返回数据集中对应索引的样本。

4、__len__ 方法返回数据集的长度(即样本数量)。

5、总结

这一部分主要讨论如何自定义数据集。截取项目代码如下。

class FWIDataset(Dataset):
  
    def __init__(self, anno, preload=True, sample_ratio=1, file_size=500,
                    transform_data=None, transform_label=None):
        
    def __getitem__(self, idx):
        return data, label if label is not None else np.array([])
        
    def __len__(self):
        return len(self.batches) * self.file_size

自定义数据集名称FWIDataset, 步骤如下:

1、继承 Dataset

Dataset 类的出处:from torch.utils.data import Dataset 

这里的目的是重写Dataset 的方法

2、__init__ 方法初始化数据集 

当执行下述代码时,会执行__init__方法。

dataset_valid = FWIDataset(
            args.val_anno,
            preload=True,
            sample_ratio=args.sample_temporal,
            file_size=ctx['file_size'],
            transform_data=transform_data,
            transform_label=transform_label
        )

__init__方法如下述所示,

    def __init__(self, anno, preload=True, sample_ratio=1, file_size=500,
                    transform_data=None, transform_label=None):
        if not os.path.exists(anno):
            print(f'Annotation file {anno} does not exists')
        self.preload = preload
        self.sample_ratio = sample_ratio
        self.file_size = file_size
        self.transform_data = transform_data
        self.transform_label = transform_label
        with open(anno, 'r') as f:
            self.batches = f.readlines()
        if preload: 
            self.data_list, self.label_list = [], []
            for batch in self.batches: 
                data, label = self.load_every(batch)
                self.data_list.append(data)
                if label is not None:
                    self.label_list.append(label)

 3、__getitem__ 方法根据给定的索引返回数据集中对应索引的样本。

当执行下述代码时,会执行__getitem__方法。

if args.distributed:
    train_sampler = DistributedSampler(dataset_train, shuffle=True)
    valid_sampler = DistributedSampler(dataset_valid, shuffle=True)
else:
    train_sampler = RandomSampler(dataset_train)
    valid_sampler = RandomSampler(dataset_valid)

dataloader_train = DataLoader(
    dataset_train, batch_size=args.batch_size,
    sampler=train_sampler, num_workers=args.workers,
    pin_memory=True, drop_last=True, collate_fn=default_collate)

dataloader_valid = DataLoader(
    dataset_valid, batch_size=args.batch_size,
    sampler=valid_sampler, num_workers=args.workers,
    pin_memory=True, collate_fn=default_collate)

在给你一个index的时候,对data、label 进行归一化、类型转化(tensor)。 

1)地震数据的处理 = log  + 极大极小值归一化

log处理:地震数据是x,log处理为log(1 + x),且数据的正负性不变。对数转换可以减小数据的范围,并且可以使得数据更符合正态分布,从而更适合一些统计分析方法。

极大极小值归一化:将数据归一化为[-1, 1]

2)速度模型的处理:将数据归一化为[-1, 1],这里不是很明白,不是应该归一化到 [0,1]?

    def __getitem__(self, idx):
        batch_idx, sample_idx = idx // self.file_size, idx % self.file_size
        if self.preload:
            data = self.data_list[batch_idx][sample_idx]
            label = self.label_list[batch_idx][sample_idx] if len(self.label_list) != 0 else None
        else:
            data, label = self.load_every(self.batches[batch_idx])
            data = data[sample_idx]
            label = label[sample_idx] if label is not None else None
        if self.transform_data:
            data = self.transform_data(data)
        if self.transform_label and label is not None:
            label = self.transform_label(label)
        return data, label if label is not None else np.array([])

4、__len__ 方法返回数据集的长度(即样本数量)。

    def __len__(self):
        return len(self.batches) * self.file_size

5、总结

综合起来看,其实就是告诉它所有数据的长度,它每次给你返回一个shuffle过的index,以这个方式遍历数据集,通过 __getitem__(self, index)返回一组你要的(data,label)

补充归一化的内容

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值