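(3) The batch-training wrapper DataLoader
PyTorch provides a handy utility called DataLoader for organizing your data. You wrap your own dataset in it, and it hands the data back in mini-batches for training.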
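import torch
import torch.utils.data as Data

BATCH_SIZE = 5  # number of samples per batch

x = torch.linspace(1, 10, 10)  # inputs: 1, 2, ..., 10
y = torch.linspace(10, 1, 10)  # targets: 10, 9, ..., 1

# TensorDataset takes the tensors positionally; the old
# data_tensor= / target_tensor= keywords were removed in later releases.
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,   # reshuffle the samples every epoch
    num_workers=2,  # load batches in two worker subprocesses
)

if __name__ == "__main__":  # guard needed where workers are spawned (e.g. Windows)
    for epoch in range(3):  # pass over the whole dataset three times
        for step, (batch_x, batch_y) in enumerate(loader):
            # A real training step would go here; this just prints each batch.
            print("Epoch: ", epoch, " | Step: ", step, " | batchx: ", batch_x.numpy(), " | batchy: ", batch_y.numpy())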
The output looks like this (the order differs from run to run because shuffle=True):
Epoch: 0 | Step: 0 | batchx: [ 6. 5. 1. 9. 3.] | batchy: [ 5. 6. 10. 2. 8.]
Epoch: 0 | Step: 1 | batchx: [ 10. 2. 7. 8. 4.] | batchy: [ 1. 9. 4. 3. 7.]
Epoch: 1 | Step: 0 | batchx: [ 6. 5. 4. 7. 9.] | batchy: [ 5. 6. 7. 4. 2.]
Epoch: 1 | Step: 1 | batchx: [ 1. 8. 3. 10. 2.] | batchy: [ 10. 3. 8. 1. 9.]
Epoch: 2 | Step: 0 | batchx: [ 5. 6. 7. 9. 3.] | batchy: [ 6. 5. 4. 2. 8.]
Epoch: 2 | Step: 1 | batchx: [ 10. 4. 8. 1. 2.] | batchy: [ 1. 7. 3. 10. 9.]
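With BATCH_SIZE = 5 the ten samples split evenly into two batches. When the batch size does not divide the dataset, the last batch of each epoch is simply smaller (DataLoader also has a drop_last argument to discard it). The sketch below illustrates the idea in plain Python, without PyTorch. Note that iter_batches and its seed parameter are hypothetical names made up for this illustration; this is not how DataLoader is actually implemented.

```python
import random


def iter_batches(samples, batch_size, shuffle=True, drop_last=False, seed=None):
    """Yield lists of at most batch_size samples (hypothetical helper)."""
    indices = list(range(len(samples)))
    if shuffle:
        # Shuffle the indices, not the data, so (x, y) pairs stay together.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # discard the incomplete final batch
        yield [samples[i] for i in batch_idx]


# Ten (x, y) pairs mirroring the tensors above: x runs 1..10, y runs 10..1.
pairs = [(float(i), float(11 - i)) for i in range(1, 11)]

for epoch in range(2):
    # A different seed per epoch stands in for reshuffling every epoch.
    for step, batch in enumerate(iter_batches(pairs, batch_size=8, seed=epoch)):
        xs = [x for x, _ in batch]
        print("Epoch:", epoch, "| Step:", step, "| batch size:", len(xs), "| batch x:", xs)
```

With a batch size of 8, every epoch yields one batch of 8 samples followed by one batch of the remaining 2. Passing drop_last=True would leave only the batch of 8.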
(4) Training a ConvNet on the CIFAR-10 dataset
# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import