Preface
The previous four articles covered model building, data preparation, and the computation methods commonly used in PyTorch. With that foundation in place, we can now train a model. This article briefly introduces how to train a deep learning model under the PyTorch framework, along with some commonly used code snippets.
![fab7c7dca755516e8db284c2b317e837.png](https://i-blog.csdnimg.cn/blog_migrate/c309dd553d6dd37bb8ab58d722c99a5e.jpeg)
Model Training
Take a simple classification model as an example; the code is as follows:
```python
import torch
import torch.nn as nn

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Iterate over the data step by step and train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
```
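The loop above assumes that `model`, `device`, `train_loader`, and the hyperparameters are already defined. A minimal sketch of that setup might look like the following; the dataset, batch size, and architecture here are purely illustrative:

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Hypothetical hyperparameters
num_epochs = 5
learning_rate = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST is used purely as an illustrative dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=64, shuffle=True)

# A minimal classifier for 28x28 grayscale images
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
```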
Custom Loss Function
Subclass torch.nn.Module and implement your own loss computation in forward.
```python
import torch

class MyLoss(torch.nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, x, y):
        # Mean squared error between x and y
        loss = torch.mean((x - y) ** 2)
        return loss
```
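A quick usage sketch; the tensor shapes here are made up for illustration:

```python
criterion = MyLoss()

x = torch.randn(4, 10, requires_grad=True)  # hypothetical predictions
y = torch.randn(4, 10)                      # hypothetical targets

loss = criterion(x, y)
loss.backward()  # gradients flow back to x as with any built-in loss
print(loss.item())
```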
L1 Regularization
```python
l1_regularization = torch.nn.L1Loss(reduction='sum')
loss = ...  # Standard cross-entropy loss
for param in model.parameters():
    # sum of |param|, computed via L1Loss against an all-zero target
    loss += l1_regularization(param, torch.zeros_like(param))
loss.backward()
```
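For reuse, the penalty can be wrapped in a small helper. The sketch below is self-contained; the model, the dummy batch, and the 1e-5 strength are all illustrative choices, and in practice the penalty is usually scaled by such a small coefficient:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def l1_penalty(model: nn.Module) -> torch.Tensor:
    # Sum of absolute values of all trainable parameters
    return sum(param.abs().sum() for param in model.parameters())

model = nn.Linear(4, 2)
outputs = model(torch.randn(8, 4))
labels = torch.randint(0, 2, (8,))

loss = F.cross_entropy(outputs, labels)
loss = loss + 1e-5 * l1_penalty(model)  # illustrative strength
loss.backward()
```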
No Weight Decay for the Bias Terms (b)
In the PyTorch framework, the weight_decay option of an optimizer is equivalent to L2 regularization.
```python
bias_list = (param for name, param in model.named_parameters()
             if name.endswith('bias'))
others_list = (param for name, param in model.named_parameters()
               if not name.endswith('bias'))

# Per-parameter options: disable weight decay for the bias group only
parameters = [{'params': bias_list, 'weight_decay': 0},
              {'params': others_list}]
optimizer = torch.optim.SGD(parameters, lr=1e-2, momentum=0.9,
                            weight_decay=1e-4)
```
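Printing the resulting parameter groups of the optimizer built above confirms the split:

```python
# The bias group has weight_decay 0; every other parameter
# keeps the global weight_decay of 1e-4
for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'])
```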
Gradient Clipping
When training deep learning models, gradients can be clipped during backpropagation to avoid the exploding-gradient problem, guaranteeing that the gradient norm never exceeds a threshold. In the PyTorch framework, a single line does the job:
```python
# Call between loss.backward() and optimizer.step();
# rescales gradients so their total norm is at most max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)
```
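For context, a sketch of where the call sits in a training step, with criterion, model, images, and labels as in the earlier training loop:

```python
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
# Clip before the parameter update so the step uses the clipped gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)
optimizer.step()
```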
How to Get the Current Learning Rate
```python
# If there is one global learning rate (which is the common case).
lr = next(iter(optimizer.param_groups))['lr']

# If there are multiple learning rates for different layers.
all_lr = []
for param_group in optimizer.param_groups:
    all_lr.append(param_group['lr'])
```
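A self-contained example of the single-rate case; the model and learning rate are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

lr = next(iter(optimizer.param_groups))['lr']
print(lr)  # 0.01
```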
To be continued...
![e45c16dd8980031f3eec5f7a11f5e612.png](https://i-blog.csdnimg.cn/blog_migrate/b7ea33d867f47c31f3be32beb47dd541.jpeg)