在学习莫烦大神的pytorch视频的batch部分,由于pytorch版本更新,产生了一些不兼容的情况。源代码如下:
import torch
import torch.utils.data as Data
torch.manual_seed(1) # 设定随机数种子
BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
torch_dataset=Data.TensorDataset(data_tensor=x,target_tensor=y)
loader=Data.DataLoader(#变成小批数据
dataset=torch_dataset,
batch_size=BATCH_SIZE,#每一组batch里面原数据个数
shuffle=True, #是否将原数据打乱分组
num_workers=2
)
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader):
print('Epoch:',epoch)
直接运行会报错,是由于Data.TensorDataset()函数版本更新后接受参数为*tensor,不再设默认值,故只需将对应行改为:
torch_dataset=Data.TensorDataset(x,y)
但是会继续报错:
The “freeze_support()” line can be omitted if the program
is not going to be frozen to produce an executable.
只需把训练过程放在if name