PyTorch 官网也给出了训练网络的实例:
# Official PyTorch tutorial training loop. Indentation is restored so the
# snippet is actually runnable; it assumes trainloader, net, criterion and
# optimizer have been set up beforehand.
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        # zero the parameter gradients (PyTorch accumulates grads by default)
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')
这是训练过程最主干的部分,但是训练前还需要一些初始化、加载等工作:
- 模块引用:
from torch.nn import MSELoss
from image_loader import *
from GNet import *
import torch.utils.data
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.models as models
- 数据加载
# Data loading: mytraindata is a project dataset class (from image_loader);
# with batch_size=1 every iteration of the loader yields a single
# (input, label) training pair.
batch_size = 1
dataset = mytraindata(".", transform=True, train=True, rescale=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
- 网络模型初始化及参数初始化
# Build GNet and seed it with pretrained VGG16 weights: every parameter
# whose name also exists in GNet's state dict is copied over; all other
# GNet layers keep their own initialization.
net = GNet()
vgg16 = models.vgg16(pretrained=True)
model_dict = net.state_dict()
pretrained_dict = {
    k: v for k, v in vgg16.state_dict().items() if k in model_dict
}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)
- loss选择
# Loss: mean squared error between the net's outputs and the targets.
criterion = MSELoss()
- 优化器选择
# The pretrained feature extractor (net.features) trains with a much
# smaller learning rate than the freshly initialized layers.
# Use a set for the id lookup: membership tests on a list are O(n) per
# check; a set makes the filter below O(1) per parameter.
gparam = set(map(id, net.features.parameters()))
# Materialize as a list (a lazy `filter` is a single-use iterator and
# would be exhausted once consumed by the optimizer constructor).
base_params = [p for p in net.parameters() if id(p) not in gparam]
# Define an optimizer with a different learning rate per parameter group.
optimizer = optim.SGD([
    {'params': base_params},
    {'params': net.features.parameters(), 'lr': 0.00001}],
    lr=0.001, momentum=0.9)
- 训练迭代:
# Fix: `device` was used below without ever being defined, which raised a
# NameError on any CUDA machine; the model was also never moved to the GPU
# even though the inputs were. Resolve the device once, up front.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
for epoch in range(100000):
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
        # Move the batch to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Gradients accumulate across backward() calls, so reset each iter.
        optimizer.zero_grad()
        # obtain the outputs of the net
        outputs = net(inputs)
        # calculate the loss between the target and outputs from the net
        loss = criterion(outputs, labels)
        print(loss, i, epoch)
        # backward pass and parameter update
        loss.backward()
        optimizer.step()
- 保存模型:
# Checkpoint every 100 epochs. Fixes: `os` was used without being imported,
# os.path.join was called with a single argument (a no-op), and saving
# failed if the 'model' directory did not exist.
import os
if epoch % 100 == 99:
    os.makedirs('model', exist_ok=True)
    model_name = os.path.join('model', 'params_%d.pkl' % epoch)
    torch.save(net.state_dict(), model_name)
完整代码如下:
from torch.nn import MSELoss
from image_loader import *
from GNet import *
import os  # was missing: os is used below for the checkpoint path
import torch.utils.data
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.models as models

# ---- Data loading: one training sample per batch ----
batch_size = 1
dataset = mytraindata(".", transform=True, train=True, rescale=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

# ---- Model: GNet seeded with pretrained VGG16 weights where names match;
# ---- the remaining layers keep their own (Gaussian) initialization.
net = GNet()
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

# Fix: `device` was referenced below without ever being defined, raising a
# NameError whenever CUDA was available. Resolve it once and move the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)

# ---- Loss: mean squared error between outputs and targets ----
criterion = MSELoss()

# The pretrained feature extractor (net.features) trains with a much
# smaller learning rate than the rest of the network.
# A set gives O(1) id-membership tests (a list would be O(n) per check).
gparam = set(map(id, net.features.parameters()))
base_params = [p for p in net.parameters() if id(p) not in gparam]
# Define an optimizer with a different learning rate per parameter group.
optimizer = optim.SGD([
    {'params': base_params},
    {'params': net.features.parameters(), 'lr': 0.00001}],
    lr=0.001, momentum=0.9)

for epoch in range(100000):
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
        # Move the batch to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Gradients accumulate across backward() calls, so reset each iter.
        optimizer.zero_grad()
        # obtain the outputs of the net
        outputs = net(inputs)
        # calculate the loss between the target and outputs from the net
        loss = criterion(outputs, labels)
        print(loss, i, epoch)
        # backward pass and parameter update
        loss.backward()
        optimizer.step()
    # Save a checkpoint every 100 epochs. Fixes: single-argument
    # os.path.join was a no-op, and the directory may not exist yet.
    if epoch % 100 == 99:
        os.makedirs('model', exist_ok=True)
        model_name = os.path.join('model', 'params_%d.pkl' % epoch)
        torch.save(net.state_dict(), model_name)