目录
3、__getitem__ 方法根据给定的索引返回数据集中对应索引的样本。
这一部分主要讨论如何自定义数据集。截取项目代码如下。
class FWIDataset(Dataset):
def __init__(self, anno, preload=True, sample_ratio=1, file_size=500,
transform_data=None, transform_label=None):
def __getitem__(self, idx):
return data, label if label is not None else np.array([])
def __len__(self):
return len(self.batches) * self.file_size
自定义数据集名称FWIDataset, 步骤如下:
1、继承 Dataset
类
Dataset 类的出处:from torch.utils.data import Dataset
这里的目的是重写Dataset 的方法
2、__init__
方法初始化数据集
当执行下述代码时,会执行__init__方法。
dataset_valid = FWIDataset(
args.val_anno,
preload=True,
sample_ratio=args.sample_temporal,
file_size=ctx['file_size'],
transform_data=transform_data,
transform_label=transform_label
)
__init__方法如下述所示,
def __init__(self, anno, preload=True, sample_ratio=1, file_size=500,
transform_data=None, transform_label=None):
if not os.path.exists(anno):
print(f'Annotation file {anno} does not exists')
self.preload = preload
self.sample_ratio = sample_ratio
self.file_size = file_size
self.transform_data = transform_data
self.transform_label = transform_label
with open(anno, 'r') as f:
self.batches = f.readlines()
if preload:
self.data_list, self.label_list = [], []
for batch in self.batches:
data, label = self.load_every(batch)
self.data_list.append(data)
if label is not None:
self.label_list.append(label)
3、__getitem__
方法根据给定的索引返回数据集中对应索引的样本。
当执行下述代码时,会执行__getitem__方法。
if args.distributed:
train_sampler = DistributedSampler(dataset_train, shuffle=True)
valid_sampler = DistributedSampler(dataset_valid, shuffle=True)
else:
train_sampler = RandomSampler(dataset_train)
valid_sampler = RandomSampler(dataset_valid)
dataloader_train = DataLoader(
dataset_train, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers,
pin_memory=True, drop_last=True, collate_fn=default_collate)
dataloader_valid = DataLoader(
dataset_valid, batch_size=args.batch_size,
sampler=valid_sampler, num_workers=args.workers,
pin_memory=True, collate_fn=default_collate)
在给你一个index的时候,对data、label 进行归一化、类型转化(tensor)。
1)地震数据的处理 = log + 极大极小值归一化
log处理:地震数据是x,log处理为log(1 + x),且
数据的正负性不变。对数转换可以减小数据的范围,并且可以使得数据更符合正态分布,从而更适合一些统计分析方法。
极大极小值归一化:将数据归一化为[-1, 1]
2)速度模型的处理:将数据归一化为[-1, 1],这里不是很明白,不是应该归一化到 [0,1]?
def __getitem__(self, idx):
batch_idx, sample_idx = idx // self.file_size, idx % self.file_size
if self.preload:
data = self.data_list[batch_idx][sample_idx]
label = self.label_list[batch_idx][sample_idx] if len(self.label_list) != 0 else None
else:
data, label = self.load_every(self.batches[batch_idx])
data = data[sample_idx]
label = label[sample_idx] if label is not None else None
if self.transform_data:
data = self.transform_data(data)
if self.transform_label and label is not None:
label = self.transform_label(label)
return data, label if label is not None else np.array([])
4、__len__
方法返回数据集的长度(即样本数量)。
def __len__(self):
return len(self.batches) * self.file_size
5、总结
综合起来看,其实就是告诉它所有数据的长度,它每次给你返回一个shuffle过的index,以这个方式遍历数据集,通过 __getitem__(self, index)返回一组你要的(data,label)
补充归一化的内容