1. Batch data training
(1) Code
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

BATCH_SIZE = 8          # number of samples per mini-batch

x = torch.linspace(1, 10, 10)   # x data (torch tensor)
y = torch.linspace(10, 1, 10)   # y data (torch tensor)

# First wrap the tensors in a Dataset that torch can work with
torch_dataset = Data.TensorDataset(x, y)

# Put the dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini-batch size
    shuffle=True,               # whether to shuffle the data (True: shuffle, usually better; False: keep order)
    # num_workers=2,            # load data with multiple worker processes (remove this line on Windows)
)

for epoch in range(5):          # pass over the whole dataset 5 times
    for step, (batch_x, batch_y) in enumerate(loader):  # each step the loader yields one mini-batch
        # training would go here...

        # print the data
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())
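A note on the commented-out num_workers line: multi-process data loading fails on Windows only when the script has no guarded entry point, because worker processes are started with spawn and re-import the module. A minimal sketch of that variant (an aside; the run results in (2) below come from the code above, without num_workers):

import torch
import torch.utils.data as Data

def main():
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    loader = Data.DataLoader(
        dataset=Data.TensorDataset(x, y),
        batch_size=8,
        shuffle=True,
        num_workers=2,          # worker processes are safe because the entry point is guarded
    )
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':      # required on Windows when num_workers > 0
    main()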
(2) Run results
Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4. 2. 1. 8.] | batch y: [ 6. 4. 1. 8. 7. 9. 10. 3.]
Epoch: 0 | Step: 1 | batch x: [9. 6.] | batch y: [2. 5.]
Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8. 5. 3. 2.] | batch y: [7. 5. 4. 1. 3. 6. 8. 9.]
Epoch: 1 | Step: 1 | batch x: [1. 9.] | batch y: [10. 2.]
Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10. 3. 9. 1.] | batch y: [ 7. 9. 6. 5. 1. 8. 2. 10.]
Epoch: 2 | Step: 1 | batch x: [8. 7.] | batch y: [3. 4.]
Epoch: 3 | Step: 0 | batch x: [ 4. 10. 9. 8. 7. 6. 1. 2.] | batch y: [ 7. 1. 2. 3. 4. 5. 10. 9.]
Epoch: 3 | Step: 1 | batch x: [5. 3.] | batch y: [6. 8.]
Epoch: 4 | Step: 0 | batch x: [9. 8. 4. 6. 5. 3. 7. 2.] | batch y: [2. 3. 7. 5. 6. 8. 4. 9.]
Epoch: 4 | Step: 1 | batch x: [10. 1.] | batch y: [ 1. 10.]
Process finished with exit code 0
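Because there are 10 samples and BATCH_SIZE = 8, each epoch produces two steps: one full batch of 8 followed by the remaining 2 samples. If every step should see exactly BATCH_SIZE samples, DataLoader's drop_last argument discards the incomplete final batch. A minimal sketch along those lines, reusing the same data as above:

import torch
import torch.utils.data as Data

torch.manual_seed(1)
BATCH_SIZE = 8

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,             # discard the final incomplete batch (the 2 leftover samples here)
)

for step, (batch_x, batch_y) in enumerate(loader):
    # with drop_last=True there is exactly one step per pass: a single full batch of 8
    print('Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())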