周末在家无聊,想了想觉得有必要整理一下如何将其他框架(如Tensorflow, MxNet, Caffe, Keras等)训练的模型导入到Pytorch框架中。
也许有人会问:既然有现成的模型为啥还转成pytorch?直接用不可以吗?
也许有人会说:直接重新训练一个就可以了吗?
也许有人会说:直接用onnx。
也许...
人生苦短,我选择pytorch,不想浪费过多时间去学习各种各样的深度学习框架。
Tensorflow用起来太复杂,caffe编译太麻烦,加层也是相当麻烦......
=======================================================
github上也许有一些keras转pytorch,caffe转pytorch的代码,甚至据说onnx也可以。但事实上,我曾试过诸多各种方法,但没有哪一种比较通用,部分仅适用于个别模型,有的甚至都无法运行,浪费了大量的时间。
最终的最终,还是走上了纯手工预训练模型导入的方式。
先简单说一下背景,本人对tensorflow,keras,caffe,mxnet,pytorch,chainer这些框架都曾有过尝试。但是最终选择了pytorch作为“本命”框架,因为它的简洁、简单、透明等等优点,简单到你会用numpy就可以了。而其他框架呢,呵呵....
目前CV领域,各种各样的开源模型,有的论文中效果宣称如何如何,但就是不开源;有的论文开源了,但用了一个你不熟悉的框架;有的甚至只给你提供了pb文件;......
上述现象严重的阻碍了新算法的学习与复现。重新训练一遍会很耗时间,且不一定能达到论文的效果,不熟悉的框架同样会浪费诸多尝试的时间。
有感于上述现象阻碍,本人尝试将其他框架预训练模型导入到pytorch中。曾转换的模型包含但不限于以下方法:
(1) [VSR_DUF](https://github.com/yhjo09/VSR-DUF)
(2) [FALSR](https://github.com/falsr/FALSR)
(3) [AmoebaNet](https://github.com/tensorflow/tpu/tree/master/models/official/amoeba_net)
(4) [EfficientNet](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
(5) [MixNet](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet)
(6) [Speech2Video](https://github.com/joonson/yousaidthat)
(7) [ZoomLearnZoom](https://github.com/ceciliavision/zoom-learn-zoom)
(8) ......
言归正传,既然是纯手工导入方法,那么就必须对待导入的模型有一定的认知了解。
事实上,纯手工导入模型要求比较简单,简单到只需要下面三个步骤即可:
- 1. 用pytorch构建待加载模型
- 2. 将预训练模型中的必要权值导出
- 3. 将导出权值赋给pytorch模型中的对应参数即可
对于CNN而言,常见的操作主要包含:卷积(2d卷积、深度分离卷积,组卷积,转置卷积)、BatchNorm、激活(ReLU, PReLU, LeaklyReLU, Swish等), 全连接层等。
在pytorch中卷积描述为:
conv = nn.Conv2d(inc, outc, kernel, stride, padding, bias)
其参数则应按照如下方式查看:
print(conv.weight)
print(conv.bias)
权值参数尺寸排布方式为:outc X inc X kernel X kernel
而tensorflow中权值的排布方式则为:kernel X kernel X inc X outc
所以在将tensorflow的权值导入到pytorch框架中时应当transpose一下!!!
另外,需要注意的一点是:深度分裂卷积比较特殊。该特殊之处自己尝试一下即可明白,^_^.
关于BatchNorm需要导入的参数包含:weight, bias, running-mean, running-var.
他们分别对应tensorflow中的:gamma,beta,mean,variance。
此处参数导入没有什么难度,直接导入即可,因为参数维度都是一样的。
另外一个需要关注的是全连接层,事实上,清除了如何导入卷积层后,查看一下pytorch与tensorflow中的全连接层参数后就全懂了。这里不过多介绍。
最后,加载转换时方式为:conv.weight = nn.parameter(torch.from_numpy(tfweight))
其他参数于此类似,不在过多介绍。
流水账形式的模型转换渔具介绍到此结束,本想附上一个转换示例,结果发现个人电脑上午相关demo。明天上班后将FALSR的转换示例开源出来,感兴趣者敬请关注。