"""
训练时,每次计算一批次的数据,然后更新一次神经网络的参数
在代码实现中,会将数据设置为一个迭代器,每次循环给出一批次的数据
pytorch中给出了Dataset,DataLoader两个接口,帮助我们实现数据迭代器
"""
import torch
from torch.utils.data import Dataset, DataLoader
# Toy samples: (integer, name) pairs numbered from 1.
test_list = list(enumerate(('dog', 'cat', 'pig', 'bird'), start=1))
BATCH_SIZE = 2  # number of samples per mini-batch yielded by the DataLoader
'''
Dataset 负责封装原始数据,并按索引逐条返回样本
1.是一个抽象类,所有自定义的dataset都应继承它
2.通常需重写方法__len__,用于返回数据集的样本数量(DataLoader默认的采样器依赖它)
3.子类必须重写方法__getitem__,用于根据给定索引返回对应的一条样本
'''
class DataSet(Dataset):
    """Map-style dataset over an in-memory sequence of pairs.

    Each stored element is a pair whose first item is converted with
    ``torch.tensor`` and whose second item is returned unchanged, so sample
    ``i`` is ``(torch.tensor(pair[0]), pair[1])``.
    """

    def __init__(self, datalist):
        """Keep a reference to *datalist* (it is not copied)."""
        self.x = datalist

    def __len__(self):
        """Return the number of samples, i.e. ``len(datalist)``."""
        return len(self.x)

    def __getitem__(self, item):
        """Return sample *item* as ``(torch.tensor(first), second)``.

        Indexing errors from the underlying sequence propagate unchanged
        (e.g. ``IndexError`` for an out-of-range index into a list).
        """
        # Look the pair up once instead of indexing the sequence twice.
        pair = self.x[item]
        return torch.tensor(pair[0]), pair[1]
# Build the dataset, then index it directly to inspect each raw sample.
torch_dataset = DataSet(test_list)
for idx in range(len(torch_dataset)):
    print(torch_dataset[idx])

# Wrap the dataset in a DataLoader that yields mini-batches of BATCH_SIZE
# samples; with shuffle=True the data is reshuffled every time it is iterated.
test_loader = DataLoader(torch_dataset, batch_size=BATCH_SIZE, shuffle=True)
def show_batch(*, num_epochs: int = 3, loader=None):
    """Print every mini-batch produced by a loader, for several epochs.

    For each epoch an ``epoch:<n>`` header is printed, followed by one
    ``step:<i>, batch_x:<x>, batch_y:<y>`` line per batch.

    Args:
        num_epochs: number of full passes over the loader (default 3).
        loader: iterable yielding ``(batch_x, batch_y)`` pairs. It must be
            re-iterable (e.g. a DataLoader or a list), because it is looped
            over once per epoch; a one-shot iterator would be empty after the
            first epoch. Defaults to the module-level ``test_loader``.
    """
    # Resolve the default lazily so the global is only needed when no loader
    # is passed in.
    if loader is None:
        loader = test_loader
    for epoch in range(num_epochs):
        print(f'epoch:{epoch}')
        for step, (batch_x, batch_y) in enumerate(loader):
            print(f"step:{step}, batch_x:{batch_x}, batch_y:{batch_y}")
# Run the demo: iterate the shuffled DataLoader for show_batch's default
# number of epochs and print every mini-batch.
show_batch()
"""
输出(shuffle=True,每个epoch的批次顺序是随机的,每次运行结果可能不同):
(tensor(1), 'dog')
(tensor(2), 'cat')
(tensor(3), 'pig')
(tensor(4), 'bird')
epoch:0
step:0, batch_x:tensor([3, 2]), batch_y:('pig', 'cat')
step:1, batch_x:tensor([4, 1]), batch_y:('bird', 'dog')
epoch:1
step:0, batch_x:tensor([3, 2]), batch_y:('pig', 'cat')
step:1, batch_x:tensor([1, 4]), batch_y:('dog', 'bird')
epoch:2
step:0, batch_x:tensor([2, 4]), batch_y:('cat', 'bird')
step:1, batch_x:tensor([3, 1]), batch_y:('pig', 'dog')
Process finished with exit code 0
"""