学习莫烦pytorch视频,部分函数写法进行更新
import torch
import torch.utils.data as Data
BATCH_SIZE=8
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
torch_dataset = Data.TensorDataset(x, y)#新版本直接用元组,(data_tensor=x, target_tensor=y)
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle=False, #是否需要随机打乱True:打乱
num_workers=2,#2个进程或者线程#不够整除
)#变成一小批一小批的
if __name__=='__main__':##加上这个才能开多线程
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
#train
print('Epoch:',epoch,'|step: ',step, '|batch x:', batch_x.numpy(), '|batch y: ', batch_y.numpy())