作者信息:
Robin Brügger, CV Lab,ETH Zürich
代码:https://github.com/RobinBruegger/RevTorch
https://github.com/RobinBruegger/PartiallyReversibleUnet
医疗影像常用3D网络,显存占用经常制约了网络结构与深度,从而对最终精度产生影响。文章主要借鉴了reversible block 的思路来解决上述问题。
reversible block
该block设计很巧妙。输入x 按通道数先分成两组,x1, x2。利用如下公式(1),得到y1,y2,由于特殊的结构设计,x1,x2反过来又可以由公式(2) 通过y1,y2计算得到。
网络训练时显存占用很大一部分是储存前向传播的中间结果(因为反向传播时需要用到),使用 reversible block 后,中间结果无需保存,只要保存最后输出的结果,中间结果都可以反推得到。
Method
文章基于MICCAI Brats18挑战赛第二名 No-New-Net 的结构进行改进,引入reversible block后的网络结构如下:
Results
结果很好,第一二行比较可以看到使用reversible block后,显存节约2.5G,使得在12G显存下使用full volume 训练成为可能,与No-New-Net的单模型比也要强。
代码
reversible block模块部分的代码如下,反向传播的代码花了一定时间才大致了解。f.backward(dy)
是链式法则的意思:把f.backward()
得到的梯度乘上之前层反传得到的梯度dy
,可以参考这个资料
import torch
import torch.nn as nn
#import torch.autograd.function as func
class ReversibleBlock(nn.Module):
'''
Elementary building block for building