pytorch学习笔记-各种Dataset的使用

前言-Dataset作用

通常在Dataset中进行数据集的“加载+预处理”,将数据集抽象成Dataset类。
在神经网络训练时通常对一个batch数据进行处理,所以,dataset类数据通常还需送入dataloader中进行batch分片处理或并行加速。

1.TensorDataset

train_dataset = TensorDataset(data,targets)
# train_dataset内部数据形式:(data_i, targets_i)

功能: 用来对 tensor数据 打包,等同于 zip 函数的功能。
用途:通常用于打包 数据 和 标签,返回打包成元组的dataset。
要求:送入该函数的两组 tensor 第一个维度大小必须相等

2.ListDataset

data = [f,e,d,c,b,a]
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=[0.409, 0.421, 0.436], std=[0.219, 0.219, 0.220])
         ])
train_dataset = TransformDataset(ListDataset(data), transform)
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

功能:将list类型数据处理成Dataset类。
用途:一些数据加载后通常用list形式暂存,用该函数转化为dataset类后,可以再送入TransformDataset等函数中进行处理。
要求:list内部没有深层结构,即,不能有多层结构的嵌套。

3.TransformDataset

功能: 对dataset进行transform处理。
用途:自定义transform操作,对dataset中数据进行进一步处理。
要求:第一个参数必须是dataset类数据。

4. 自定义Dataset注意事项

必须继承Dataset类,并实现如下两个函数:

• __getitem__:返回一条数据或一个样本。
		实际调用时,obj[index]等价于obj.__getitem__(index)。
• __len__:返回样本的数量。
 		实际调用时,len(obj)等价于obj.__len__()

示例:

class MyDataset(Dataset):
    # Initialize your data, download, etc.
    def __init__(self, data, targets, transforms):
        self.len = len(data)
        self.data = data
        self.targets = targets
        self.transforms = transforms
    def __getitem__(self, index): # 根据索引返回数据和对应的标签
        r_data = self.transforms(self.data[index])  # 相当于TransformDataset的操作
        return r_data, self.targets[index]  # 相当于打包操作
    def __len__(self):
        return self.len

参考:

PyTorch 小功能之 TensorDataset

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值