在训练比较大、耗时较久的网络时,如果突然停电、断网或者一些意外情况发生导致训练中断,那么已经训练好的内容可能全部丢失,这时我们就需要在训练过程中把一些时间点的checkpoint保存下来,及时训练意外中断,那么我们也可以在之后把这些checkpoint下载下来,重新开始训练。
(谁能想到我刚刚码好这段话就停电了呢????)

以下内容大部分和
cifar-10+resnet.
一样,重点在load_state_dict的,可以直接跳转:
戳这里↓
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
trans = transforms.Compose((transforms.Resize(32),transforms.ToTensor()))
cifar_train = datasets.CIFAR10('cifar',train = True,transform=trans)
cifar_train_batch = DataLoader(cifar_train,batch_size=30,shuffle=True)
cifar_test = datasets.CIFAR10('cifar',train = False,transform=trans)
cifar_test_batch = DataLoader(cifar_test,batch_size=30,shuffle=True)
#搭建resnet
class resblock(nn.Module):
def __init__(self,ch_in,ch_out,stride):
super(resblock,self).__init__()
self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn_1 = nn.BatchNorm2d(ch_out)
self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn_2 = nn.BatchNorm2d(ch_out)
self.ch_in,self.ch_out,self.stride = ch_in,ch_out,stride
self.ch_trans = nn.Sequential()
if ch_in != ch_out

本文介绍了在PyTorch中如何使用load_state_dict来保存和恢复模型的训练进度。通过保存checkpoint,可以防止因意外中断导致的训练损失。以CIFAR-10数据集和ResNet模型为例,展示了如何在已有的训练基础上继续训练,从而提高模型的准确率。
最低0.47元/天 解锁文章
2594

被折叠的 条评论
为什么被折叠?



