Pytorch Gradient Checkpoint使用示例

训练深度学习模型过程中,经常会遇到CUDA error: out of memory(OOM)的问题。有一些简单粗暴但不elegent的解决办法:

  • 减小Batch Size, e.g 32 → \rightarrow 16
  • 减小输入的大小,e.g.332 × \times × 332 × \times × 3 → \rightarrow 224 × \times × 224 × \times × 3
  • 换一块显存更大的GPU

查了一下,PyTorch提供了一种更优雅的解决方式gradient checkpoint(查了一下应该是0.4.0之后引入的新功能),以计算时间换内存的方式,显著减小模型训练对GPU的占用。在我的模型里,使用gradien checkpoint后,显存占用节省约30%。

虽然PyTorch的gradient checkpoint使用非常简单,但刚开始接触还是希望能有一些示例可以参考。网上找了好久才找到了一篇示例参考。这里给出一个UNet的使用示例,也把过程中遇到的问题和解决办法总结下来。

gradient checkpoint

PyTorch的gradient checkpoint是通过torch.utils.checkpoint.checkpoint(function, *args, **kwargs)函数实现的。
这里把PyTorch官方文档中关于该函数的介绍引用翻译如下:

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model.

Gradient Checkpoint是通过以更长的计算时间为代价,换取更少的显存占用。相比于原本需要存储所有中间变量以供反向传播使用,使用了checkpoint的部分不存储中间变量而是在反向传播过程中重新计算这些中间变量。模型中的任何部分都可以使用gradient checkpoint。

gradient checkpoint使用示例

这里以UNet来演示如何使用gradient checkpoint.

import torch 
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.checkpoint import checkpoint

class in_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(in_conv, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        def forward(self, x):
        	x = self.op(x)
        	return x

class conv3x3(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(conv3x3, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(in_ch,out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch,out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.op(x)
        return x
        
class down_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down_block, self).__init__()
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv = conv3x3(in_ch, out_ch)
        
    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x

class up_block(nn.Module):
    def __init__(self, in_ch, out_ch, residual=False):
        super(up_block, self).__init__()
        
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = conv3x3(in_ch, out_ch)
        
    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)

        return x

class out_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(out_conv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        return x

class UNet(nn.Module):
    def __init__(self, img_channels, n_classes,use_checkpoint=False):
        super(UNet, self).__init__()
        self.inc = in_conv(img_channels,64)
        self.down1 = down_block(64, 128)
        self.down2 = down_block(128, 256)
        self.down3 = down_block(256, 512)
        self.down4 = down_block(512, 1024)
        self.up1 = up_block(1024, 512)
        self.up2 = up_block(512, 256)
        self.up3 = up_block(256, 128)
        self.up4 = up_block(128, 64)
        self.outc = out_conv(64, 1)

    def forward(self, x):
        def forward(self, x):
        x = Variable(x,requires_grad=True) 
        if self.use_checkpoint:
            x1 = checkpoint(self.inc,x)
            x2 = checkpoint(self.down1,x1)
            x3 = checkpoint(self.down2,x2)
            x4 = checkpoint(self.down3,x3)
            x5 = checkpoint(self.down4,x4)
            x = checkpoint(self.up1,x5,x4)
            x = checkpoint(self.up2,x,x3)
            x = checkpoint(self.up3,x,x2)
            x = checkpoint(self.up4,x,x1)
            x = checkpoint(self.outc,x)
        else:
        	x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            x = self.outc(x)

        return x

注意第94行,必须确保checkpoint的输入输出都声明为require_grad=True的Variable,否则运行时会报如下的错

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

GPU使用情况监控

推荐一个比nvidia-smi更好用的gpu监控方式——gpustat.
nvidia-smi如果要持续监控GPU使用情况的话,需要loop nvidia-smi,持续打印,不太美观也不易于查看。
gpustat可以动态监控,开一个tab运行就可以持续观测GPU使用情况的变化了。
安装方式

pip install gpustat

使用方式

watch  --color -n1 gpustat -cpu
  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: PyTorch Checkpoint是一种用于保存和恢复模型状态的工具。它可以在训练过程中定期保存模型的状态,以便在需要时恢复模型的状态。以下是PyTorch Checkpoint使用方法: 1. 导入必要的库: ``` import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data from torch.utils.data import DataLoader ``` 2. 定义模型: ``` class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = nn.Linear(10, 20) self.fc2 = nn.Linear(20, 2) def forward(self, x): x = self.fc1(x) x = nn.ReLU()(x) x = self.fc2(x) return x model = MyModel() ``` 3. 定义优化器和损失函数: ``` optimizer = optim.Adam(model.parameters(), lr=.001) criterion = nn.CrossEntropyLoss() ``` 4. 定义数据集和数据加载器: ``` train_dataset = MyDataset() train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) ``` 5. 定义训练循环: ``` for epoch in range(10): for i, (inputs, labels) in enumerate(train_loader): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() if i % 100 == : checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item() } torch.save(checkpoint, 'checkpoint.pth') ``` 6. 定义恢复模型状态的函数: ``` def load_checkpoint(checkpoint_path): checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] return model, optimizer, epoch, loss ``` 7. 使用恢复模型状态的函数恢复模型状态: ``` model, optimizer, epoch, loss = load_checkpoint('checkpoint.pth') ``` 以上就是PyTorch Checkpoint使用方法。 ### 回答2: PyTorch是一个开源的深度学习框架,一般用于训练神经网络以及其他深度学习模型。PyTorch提供了checkpoint这个类来进行模型训练时的状态保存和恢复。 checkpoint是一个类,需要导入torch.utils.checkpoint,通过这个类可以实现动态图模型的中间结果的保存。 当我们训练一个深度神经网络时,模型可能会非常大,可能需要几天或几周才能完成训练。为了避免在训练过程中出现问题,需要对模型中间结果进行保存。而PyTorchcheckpoint就可以实现这个功能。 checkpoint使用方法非常简单,可以在代码中使用下列方式进行: torch.utils.checkpoint.save(file_path, **kwargs) 其中,file_path是保存文件的路径,可以是绝对路径或相对路径,kwargs是用于保存的参数。 可以通过如下代码进行重载: torch.utils.checkpoint.load(file_path) 其中,file_path是要加载的checkpoint文件的路径。 checkpoint的具体使用方式可以在模型训练的时候进行调用,如下: for i, (inputs, labels) in enumerate(train_loader): inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs, labels) loss.backward() optimizer.step() if (i + 1) % checkpoint_frequency == 0: checkpoint(model, optimizer, loss, i, file_path) 以上是checkpoint使用方法,可以有效保证模型训练过程中的结果,让深度学习工程师更加方便的管理和优化模型。 ### 回答3: PyTorch是一个非常流行的深度学习框架,它可以帮助构建和训练深度学习模型。PyTorch中提供了Checkpoint(检查点)功能,可以保存模型的状态,以便在训练期间或之后重新启动模型,并从上次离开的地方继续训练模型。本文将介绍PyTorch Checkpoint使用方法。 定义模型 在开始Checkpoint之前,需要首先定义模型,这包括模型的结构和超参数的设定。例如,我们可以使用以下代码定义一个简单的卷积神经网络: import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x 设定优化器和损失函数 接下来,需要定义模型的优化器和损失函数。例如,我们可以使用以下代码定义SGD优化器和交叉熵损失函数: net = Net() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 定义Checkpoint 接下来,我们需要定义一个Checkpoint,以便在训练过程中保存模型。以下是Checkpoint的定义方式: checkpoint = { 'epoch': epoch + 1, 'state_dict': net.state_dict(), 'optimizer' : optimizer.state_dict(), } 这里的checkpoint是一个Python字典,其中包含三个元素: epoch:表示当前训练的轮数,也就是模型训练到了哪个轮数; state_dict:表示模型的状态,其中包括所有的权重、偏置、梯度等; optimizer:表示优化器的状态,其中包括优化器的参数和状态。 保存Checkpoint 接下来,我们可以使用以下代码保存Checkpoint: torch.save(checkpoint, 'checkpoint.pth') 这里的checkpoint.pth是保存Checkpoint的文件名。我们可以把这个文件名命名为任何我们想要的名字。 恢复Checkpoint 当我们需要恢复Checkpoint时,可以使用以下代码: checkpoint = torch.load('checkpoint.pth') net.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) 这里的checkpoint.pth是我们之前保存的文件名,我们将其加载到checkpoint变量中。然后,我们可以使用load_state_dict()函数将模型参数加载到我们的神经网络中,使用load_state_dict()函数将优化器状态加载到我们的优化器中。 使用Checkpoint 当我们恢复Checkpoint后,我们可以继续训练模型。以下是如何使用Checkpoint继续训练模型的示例代码: for epoch in range(start_epoch, END_EPOCH): train(epoch) checkpoint = { 'epoch': epoch + 1, 'state_dict': net.state_dict(), 'optimizer' : optimizer.state_dict(), } torch.save(checkpoint, 'checkpoint.pth') 在训练期间,我们可以保存多个Checkpoint,每个Checkpoint代表不同的训练状态。在保存Checkpoint时,我们可以指定保存Checkpoint的文件名。当需要恢复Checkpoint时,我们只需要将对应的文件名加载到checkpoint中即可。 综上所述,Checkpoint是一个很方便的工具,可以帮助我们在训练中保存模型的状态,以便之后恢复模型,并继续训练模型。在实际应用中,我们可以根据不同的需要,定制自己的Checkpoint保存策略。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值