pytorch的dataset里面的__getitem()__解读

代码如下:

# -*- coding: utf-8 -*-
import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
class TensorDataset(Data.Dataset):
    """Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
    def __getitem__(self, index):
        print('idex',index)
        a=tuple(tensor[index] for tensor in self.tensors)
        print('a', a)
        return tuple(tensor[index] for tensor in self.tensors)
    def __len__(self):
        return self.tensors[0].size(0)     
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
'''先转换成 torch 能识别的 Dataset'''
torch_dataset =TensorDataset(x,y) #Data.TensorDataset(x, y)
#print(torch_dataset[0])     #输出(tensor(1.), tensor(10.))
#print(torch_dataset[1])     #输出(tensor(2.), tensor(9.))
''' 把 dataset 放入 DataLoader'''
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # 要不要打乱数据 (打乱比较好)
    #num_workers=2,              # subprocesses for loading data
)
for epoch in range(3):   # train entire dataset 3 times
    for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
        # train your data...
        print('ok')
       # print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
             # batch_x.numpy(), '| batch y: ', batch_y.numpy())
    #结果为:         
'''idex 4
a (tensor(5.), tensor(6.))
idex 6
a (tensor(7.), tensor(4.))
idex 9
a (tensor(10.), tensor(1.))
idex 2
a (tensor(3.), tensor(8.))
idex 3
a (tensor(4.), tensor(7.))
ok
idex 1
a (tensor(2.), tensor(9.))
idex 0
a (tensor(1.), tensor(10.))
idex 7
a (tensor(8.), tensor(3.))
idex 8
a (tensor(9.), tensor(2.))
idex 5
a (tensor(6.), tensor(5.))
ok
idex 3
a (tensor(4.), tensor(7.))
idex 5
a (tensor(6.), tensor(5.))
idex 6
a (tensor(7.), tensor(4.))
idex 9
a (tensor(10.), tensor(1.))
idex 7
a (tensor(8.), tensor(3.))
ok
idex 4
a (tensor(5.), tensor(6.))
idex 2
a (tensor(3.), tensor(8.))
idex 1
a (tensor(2.), tensor(9.))
idex 0
a (tensor(1.), tensor(10.))
idex 8
a (tensor(9.), tensor(2.))
ok
idex 3
a (tensor(4.), tensor(7.))
idex 1
a (tensor(2.), tensor(9.))
idex 4
a (tensor(5.), tensor(6.))
idex 5
a (tensor(6.), tensor(5.))
idex 9
a (tensor(10.), tensor(1.))
ok
idex 2
a (tensor(3.), tensor(8.))
idex 8
a (tensor(9.), tensor(2.))
idex 0
a (tensor(1.), tensor(10.))
idex 7
a (tensor(8.), tensor(3.))
idex 6
a (tensor(7.), tensor(4.))
ok'''
  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值