torch.utils.checkpoint 简介 和 简易使用

5个月之前说过要写个的,不过因为觉得过于简单,和懒,就鸽了。

pytorch 的 checkpoint 是一种用时间换显存的技术,一般训练模式下,pytorch 每次运算后会保留一些中间变量用于求导,而使用 checkpoint 的函数,则不会保留中间变量,中间变量会在求导时再计算一次,因此减少了显存占用,跟 tensorflow 的 checkpoint 是完全不同的东西。

这个 checkpoint 用的好的话,训练时相比不使用 checkpoint 的模型可以增加 30% 的批量大小。
不过对非训练模式(torch.no_grad)没有用,因为非训练模式不需要求导,也不会有中间变量产生。

使用 checkpoint 增大批量与 多次 backward 增大批量的方法相比,checkpoint 显然效果更好,因为batchnorm 可以更加准确,并且相同批量大小下 checkpoint 的速度比 多次 backward 的速度更快,一般情况下优选 checkpoint

以下用一个简单的 cifar10 分类实验 来测试 checkpoint 的功能

测试环境 win10 1903, gtx970m-3g
但因为是windows系统,实际只能用2.5G显存。。。
开启 checkpoint 的情况下,我的机器支持的最大批量大小是31
关闭 checkpoint 的情况下,我的机器支持的最大批量大小只有23
因为cifar10的图像分辨率太小了,吃显存也少,所以我把分辨率图像分辨率直接上采样到224x224来吃一堆显存

下列代码
第81行选择是否使用 checkpoint
第96行选择 batch_size

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torchvision.datasets.cifar import CIFAR10
import numpy as np
from progressbar import progressbar


def conv_bn_relu(in_ch, out_ch, ker_sz, stride, pad):
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, ker_sz, stride, pad, bias=False),
                         nn.BatchNorm2d(out_ch),
                         nn.ReLU())


class NetA(nn.Module):
    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        k = 2
        # 32x32
        self.layer1 = conv_bn_relu(3, 32*k, 3, 1, 1)
        self.layer2 = conv_bn_relu(32*k, 32*k, 3, 2, 1)
        # 16x16
        self.layer3 = conv_bn_relu(32*k, 64*k, 3, 1, 1)
        self.layer4 = conv_bn_relu(64*k, 64*k, 3, 2, 1)
        # 8x8
        self.layer5 = conv_bn_relu(64*k, 128*k, 3, 1, 1)
        self.layer6 = conv_bn_relu(128*k, 128*k, 3, 2, 1)
        # 4x4
        self.layer7 = conv_bn_relu(128*k, 256*k, 3, 1, 1)
        self.layer8 = conv_bn_relu(256*k, 256*k, 3, 2, 1)
        # 1x1
        self.layer9 = nn.Linear(256*k, 10)

    def seg0(self, y):
        y = self.layer1(y)
        return y

    def seg1(self, y):
        y = self.layer2(y)
        y = self.layer3(y)
        return y

    def seg2(self, y):
        y = self.layer4(y)
        y = self.layer5(y)
        return y

    def seg3(self, y):
        y = self.layer6(y)
        y = self.layer7(y)
        return y

    def seg4(self, y):
        y = self.layer8(y)
        y = F.adaptive_avg_pool2d(y, 1)
        y = torch.flatten(y, 1)
        y = self.layer9(y)
        return y

    def forward(self, x):
        y = x
        # y = self.layer1(y)
        y = y + torch.zeros(1, dtype=y.dtype, device=y.device, requires_grad=True)
        if self.use_checkpoint:
            # 使用 checkpoint
            y = checkpoint(self.seg0, y)
            y = checkpoint(self.seg1, y)
            y = checkpoint(self.seg2, y)
            y = checkpoint(self.seg3, y)
            y = checkpoint(self.seg4, y)
        else:
            # 不使用 checkpoint
            y = self.seg0(y)
            y = self.seg1(y)
            y = self.seg2(y)
            y = self.seg3(y)
            y = self.seg4(y)

        return y


if __name__ == '__main__':
    net = NetA(use_checkpoint=True).cuda()

    train_dataset = CIFAR10('../datasets/cifar10', True, download=True)
    train_x = np.asarray(train_dataset.data, np.uint8)
    train_y = np.asarray(train_dataset.targets, np.int)

    losser = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(net.parameters(), 1e-3)

    epoch = 10
    batch_size = 31
    batch_count = int(np.ceil(len(train_x) / batch_size))

    for e_id in range(epoch):
        print('epoch', e_id)

        print('training')
        net.train()
        loss_sum = 0
        for b_id in progressbar(range(batch_count)):
            optim.zero_grad()

            batch_x = train_x[batch_size*b_id: batch_size*(b_id+1)]
            batch_y = train_y[batch_size*b_id: batch_size*(b_id+1)]

            batch_x =  torch.from_numpy(batch_x).permute(0, 3, 1, 2).float() / 255.
            batch_y =  torch.from_numpy(batch_y).long()

            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()

            batch_x = F.interpolate(batch_x, (224, 224), mode='bilinear')

            y = net(batch_x)
            loss = losser(y, batch_y)
            loss.backward()
            optim.step()
            loss_sum += loss.item()
        print('loss', loss_sum / batch_count)

        with torch.no_grad():
            print('testing')
            net.eval()
            acc_sum = 0
            for b_id in progressbar(range(batch_count)):
                optim.zero_grad()

                batch_x = train_x[batch_size * b_id: batch_size * (b_id + 1)]
                batch_y = train_y[batch_size * b_id: batch_size * (b_id + 1)]

                batch_x = torch.from_numpy(batch_x).permute(0, 3, 1, 2).float() / 255.
                batch_y = torch.from_numpy(batch_y).long()

                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

                batch_x = F.interpolate(batch_x, (224, 224), mode='bilinear')

                y = net(batch_x)

                y = torch.topk(y, 1, dim=1).indices
                y = y[:, 0]

                acc = (y == batch_y).float().sum() / len(batch_x)

                acc_sum += acc.item()
            print('acc', acc_sum / batch_count)

        ids = np.arange(len(train_x))
        np.random.shuffle(ids)
        train_x = train_x[ids]
        train_y = train_y[ids]

  • 26
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 39
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 39
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值