# 代码如下 (The code is as follows):
# -*- coding: utf-8 -*-
import torch
import torch.utils.data as Data
# Seed the global RNG so that the DataLoader's shuffle order is the same on
# every run (this is what makes the sample output recorded below repeatable).
torch.manual_seed(seed=1)
class TensorDataset(Data.Dataset):
    """Dataset wrapping tensors, printing every index that is requested.

    Each sample is retrieved by indexing every wrapped tensor along its first
    dimension. The requested index and the resulting sample are printed, which
    makes it possible to watch the order in which a ``DataLoader`` asks for
    samples (e.g. with ``shuffle=True``).

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension.
    """

    def __init__(self, *tensors):
        # All tensors must agree on their first dimension; otherwise a single
        # index would not address the same sample in each of them.
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        """Return a tuple holding ``tensor[index]`` for every wrapped tensor.

        The index and the resulting sample are printed as a debugging trace.
        """
        print('idex', index)
        sample = tuple(tensor[index] for tensor in self.tensors)
        print('a', sample)
        # Return the tuple already built for the trace instead of indexing
        # every tensor a second time.
        return sample

    def __len__(self):
        """Return the number of samples (size of the first tensor's dim 0)."""
        return self.tensors[0].size(0)
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)  # x data: tensor([1., 2., ..., 10.])
y = torch.linspace(10, 1, 10)  # y data: tensor([10., 9., ..., 1.])

# Wrap the tensors in a Dataset that torch can consume. The local tracing
# TensorDataset is used instead of Data.TensorDataset(x, y) so that every
# index requested by the DataLoader is printed.
torch_dataset = TensorDataset(x, y)
# torch_dataset[0] -> (tensor(1.), tensor(10.))
# torch_dataset[1] -> (tensor(2.), tensor(9.))

# Put the dataset into a DataLoader, which groups samples into mini-batches.
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=BATCH_SIZE,  # mini batch size
    shuffle=True,  # shuffle the data each epoch (usually better for training)
    # num_workers=2,  # subprocesses for loading data
    # NOTE(review): enabling worker processes without an
    # `if __name__ == "__main__":` guard breaks on platforms that spawn
    # workers (e.g. Windows) -- see the DataLoader docs before turning it on.
)

for epoch in range(3):  # train on the entire dataset 3 times
    for step, (batch_x, batch_y) in enumerate(loader):  # one step per batch
        # train your data...
        print('ok')
        # print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
        #       batch_x.numpy(), '| batch y: ', batch_y.numpy())
# 结果为 (Output):
'''idex 4
a (tensor(5.), tensor(6.))
idex 6
a (tensor(7.), tensor(4.))
idex 9
a (tensor(10.), tensor(1.))
idex 2
a (tensor(3.), tensor(8.))
idex 3
a (tensor(4.), tensor(7.))
ok
idex 1
a (tensor(2.), tensor(9.))
idex 0
a (tensor(1.), tensor(10.))
idex 7
a (tensor(8.), tensor(3.))
idex 8
a (tensor(9.), tensor(2.))
idex 5
a (tensor(6.), tensor(5.))
ok
idex 3
a (tensor(4.), tensor(7.))
idex 5
a (tensor(6.), tensor(5.))
idex 6
a (tensor(7.), tensor(4.))
idex 9
a (tensor(10.), tensor(1.))
idex 7
a (tensor(8.), tensor(3.))
ok
idex 4
a (tensor(5.), tensor(6.))
idex 2
a (tensor(3.), tensor(8.))
idex 1
a (tensor(2.), tensor(9.))
idex 0
a (tensor(1.), tensor(10.))
idex 8
a (tensor(9.), tensor(2.))
ok
idex 3
a (tensor(4.), tensor(7.))
idex 1
a (tensor(2.), tensor(9.))
idex 4
a (tensor(5.), tensor(6.))
idex 5
a (tensor(6.), tensor(5.))
idex 9
a (tensor(10.), tensor(1.))
ok
idex 2
a (tensor(3.), tensor(8.))
idex 8
a (tensor(9.), tensor(2.))
idex 0
a (tensor(1.), tensor(10.))
idex 7
a (tensor(8.), tensor(3.))
idex 6
a (tensor(7.), tensor(4.))
ok'''