pytorch torch.utils.data.Dataset

应用

from torch.utils.data import DataLoader, Dataset
import torch

class TensorDataset(Dataset):
    # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]
    def __len__(self):
        return self.data_tensor.size(0)

# 生成数据
data_tensor = torch.randn(3, 2)
target_tensor = torch.rand(3)
# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)
# 可使用索引调用数据
tensor_dataset[0] # 等同于(data_tensor[0],target_tensor[0])
# 可返回数据len
len(tensor_dataset)

API

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite getitem(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite len(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

参考:
https://zhuanlan.zhihu.com/p/28200166
https://pytorch.org/docs/master/data.html

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值