1. 前言
我想把一个TensorFlow代码转为pytorch代码,深度学习的代码。经过一个月的调试。。。。。自己好垃圾啊。。。
2.目标
将某tensorflow代码转pytorch。
3.过程
- 阅读需要复现的原文
- 很快啊,就一天可能就把TensorFlow的网络结构照猫画虎的写成了pytorch
- 然后就进入了无限调bug阶段。。。持续两周左右
- 最后想要放弃的时候,打算搭建TensorFlow的环境跑作者提供的代码,这时灵光一现,发现了自己哪里写错了。。。
4.建议的过程
网络结构或许很简单,但是你如果不清楚细节,直接照猫画虎,那100%的出错,后期花费的时间将是超级大的代价!!!
4.1 论文阶段
- 粗读论文!
- 粗读代码+精读论文。这一阶段一定要清晰的理解作者的思想,把代码和文章读通
- 精读代码。这时把代码一字一字的读,我师兄说,他之前把别人的代码打印出来,一字一句的看,看到心中有数才开始写代码
4.2 代码阶段
TensorFlow的很多函数和Pytorch的函数、默认参数、初始化很多不一样,需要注意每个函数
- 必须搭建作者建议的环境,要什么环境,搭建什么环境
- 跑通作者提供的代码
- 调试作者提供的代码,一般训练时间很长,输出sample的间隔会很长,设定两步输出一个结果,运行较少的次数之后停止。(可以初步的看到处理结果是理想的时候)这个结果保留下来与自己之后复现的对比
- 复现论文的网络结构(不要一次写完结构调试,调好一部分,然后再调一部分)
- 复现论文的loss
- 复现论文的数据处理增强导入的部分
- 调bug到,程序可以运行起来
- 运行代码之后,打印输出网络的结构。
- 对比网络结构是否复现正确。经过这个过程,我发现自己的最后一层conv层加了batch_norm,坏事了。。。
- 调试作者代码,打印构成loss的每一个部分的值,打印自己复现的loss,保证在一个数量级上。经过这一过程,我发现,一个loss写错了。
- 查看数据增强,数据处理,预训练模型的加载。TensorFlow和pytorch不一样,但是流程类似。这一步我发现,我用cv.imread读的是BGR,pytorch官方的VGG19模型用的PIL,是RGB。。。。另外预训练模型加载没有问题,但是我加载完之后,逐步调试发现,加载完预训练模型之后,进行了参数初始化。。。。
- 调试复现的代码,直到与步骤3中得到的前期运行结果类似,然后才可以运行完论文规定的迭代次数。这里对比前期的结果,基本就可以解决所有的bug,直到现在调试过程结束。
- 部署代码、评估复现的结果与原文的差距。
- 复现结束,你可以总结以下,其中需要注意的地方,然后将好的部分借鉴到自己的实验中,或者进一步改进模型。直到现在,你才可以着手改进模型!!!前期调试过程,可以记录改进的思路,但是不要改,力求与源代码的一致性。因为你调bug的时候,会不清楚是哪里出的问题!
5. 这一切的原因-分析
自己这段时间太急躁了!!!我可以静下心来看论文,但是一旦开始写代码,恨不得一天写完,写完了代码运行,感觉自己的干完了,但是写代码的过程中急躁是要不得的!!!写代码用了一天,调bug用了10天?事后回头看,总结一下自己的不足:
- 急躁
- 不按部就班的做,一步一步的来,不要没学会走,就要跨栏。但是因为跨栏太过诱人,你非要跨栏?突然想起来飞蛾扑火这个词。
- 代码复现不是科研!!是工程,所以要按照工程的思路来做,一步一步,走流程
- idea需要灵光一闪,但是写代码,一定要按部就班!!!
- 不要偷懒,最开始想的装TensorFlow的环境太费劲,然后就复现pytorch的代码,所以就有了这糟心的一个月
- 顶会的论文是有含金量的,在没有深入分析之后,切不可怀疑,复现就是复现!!!等你复现之后,在改动一些你认为可以提升性能的地方。就像英语翻译,老师一直教我们“信、达、雅”,保真是翻译的核心!!!
6.参数结构打印
TensorFlow1.12的打印结构:
for var in tf.trainable_variables():
print("Listing trainable variables ... ")
print(var)
TensorFlow1.12的打印参数:1
import tensorflow as tf
reader = tf.train.NewCheckpointReader('logs/pre-trained/lasted_model.ckpt')
global_variables = reader.get_variable_to_shape_map()
for key in global_variables:
print("tensor_name: ", key)
print(reader.get_tensor(key))
pytorch 打印结构:
net = resnet()#实例化网络
print(net)
pytorch 打印参数:2
for name, parameters in net.named_parameters():#打印出每一层的参数的大小
print(name, ':', parameters.size())
for param_tensor in net.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
print(param_tensor, '\t', net.state_dict()[param_tensor])
python打印到文本中:3
f = open('text.txt', 'w')
print('abc', file = f, flush=False)
f.close()