深度学习顶会论文 复现 tensorflow代码转pytorch 吐血入门经验

1. 前言

我想把一个TensorFlow代码转为pytorch代码,深度学习的代码。经过一个月的调试。。。。。自己好垃圾啊。。。

2.目标

将某tensorflow代码转pytorch。

3.过程

  1. 阅读需要复现的原文
  2. 很快啊,就一天可能就把TensorFlow的网络结构照猫画虎的写成了pytorch
  3. 然后就进入了无限调bug阶段。。。持续两周左右
  4. 最后想要放弃的时候,打算搭建TensorFlow的环境跑作者提供的代码,这时灵光一现,发现了自己哪里写错了。。。

4.建议的过程

网络结构或许很简单,但是你如果不清楚细节,直接照猫画虎,那100%的出错,后期花费的时间将是超级大的代价!!!

4.1 论文阶段

  1. 粗读论文!
  2. 粗读代码+精读论文。这一阶段一定要清晰的理解作者的思想,把代码和文章读通
  3. 精读代码。这时把代码一字一字的读,我师兄说,他之前把别人的代码打印出来,一字一句的看,看到心中有数才开始写代码

4.2 代码阶段

TensorFlow的很多函数和Pytorch的函数、默认参数、初始化很多不一样,需要注意每个函数

  1. 必须搭建作者建议的环境,要什么环境,搭建什么环境
  2. 跑通作者提供的代码
  3. 调试作者提供的代码,一般训练时间很长,输出sample的间隔会很长,设定两步输出一个结果,运行较少的次数之后停止。(可以初步的看到处理结果是理想的时候)这个结果保留下来与自己之后复现的对比
  4. 复现论文的网络结构(不要一次写完结构调试,调好一部分,然后再调一部分)
  5. 复现论文的loss
  6. 复现论文的数据处理增强导入的部分
  7. 调bug到,程序可以运行起来
  8. 运行代码之后,打印输出网络的结构。
  9. 对比网络结构是否复现正确。经过这个过程,我发现自己的最后一层conv层加了batch_norm,坏事了。。。
  10. 调试作者代码,打印构成loss的每一个部分的值,打印自己复现的loss,保证在一个数量级上。经过这一过程,我发现,一个loss写错了。
  11. 查看数据增强,数据处理,预训练模型的加载。TensorFlow和pytorch不一样,但是流程类似。这一步我发现,我用cv.imread读的是BGR,pytorch官方的VGG19模型用的PIL,是RGB。。。。另外预训练模型加载没有问题,但是我加载完之后,逐步调试发现,加载完预训练模型之后,进行了参数初始化。。。。
  12. 调试复现的代码,直到与步骤3中得到的前期运行结果类似,然后才可以运行完论文规定的迭代次数。这里对比前期的结果,基本就可以解决所有的bug,直到现在调试过程结束。
  13. 部署代码、评估复现的结果与原文的差距。
  14. 复现结束,你可以总结以下,其中需要注意的地方,然后将好的部分借鉴到自己的实验中,或者进一步改进模型。直到现在,你才可以着手改进模型!!!前期调试过程,可以记录改进的思路,但是不要改,力求与源代码的一致性。因为你调bug的时候,会不清楚是哪里出的问题!

5. 这一切的原因-分析

自己这段时间太急躁了!!!我可以静下心来看论文,但是一旦开始写代码,恨不得一天写完,写完了代码运行,感觉自己的干完了,但是写代码的过程中急躁是要不得的!!!写代码用了一天,调bug用了10天?事后回头看,总结一下自己的不足:

  1. 急躁
  2. 不按部就班的做,一步一步的来,不要没学会走,就要跨栏。但是因为跨栏太过诱人,你非要跨栏?突然想起来飞蛾扑火这个词。
  3. 代码复现不是科研!!是工程,所以要按照工程的思路来做,一步一步,走流程
  4. idea需要灵光一闪,但是写代码,一定要按部就班!!!
  5. 不要偷懒,最开始想的装TensorFlow的环境太费劲,然后就复现pytorch的代码,所以就有了这糟心的一个月
  6. 顶会的论文是有含金量的,在没有深入分析之后,切不可怀疑,复现就是复现!!!等你复现之后,在改动一些你认为可以提升性能的地方。就像英语翻译,老师一直教我们“信、达、雅”,保真是翻译的核心!!!

6.参数结构打印

TensorFlow1.12的打印结构:

for var in tf.trainable_variables():
    print("Listing trainable variables ... ")
    print(var)

TensorFlow1.12的打印参数:1

import tensorflow as tf

reader = tf.train.NewCheckpointReader('logs/pre-trained/lasted_model.ckpt')

global_variables = reader.get_variable_to_shape_map()
for key in global_variables:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

pytorch 打印结构:

net = resnet()#实例化网络
print(net)

pytorch 打印参数:2

for name, parameters in net.named_parameters():#打印出每一层的参数的大小
       print(name, ':', parameters.size())

for param_tensor in net.state_dict():  # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
        print(param_tensor, '\t', net.state_dict()[param_tensor])

python打印到文本中:3

f = open('text.txt', 'w')
print('abc', file = f, flush=False)
f.close()

  1. https://www.cnblogs.com/adong7639/p/7764769.html ↩︎

  2. https://blog.csdn.net/chunfeng0301/article/details/108227660 ↩︎

  3. https://www.zky.name/article/58.html ↩︎

评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

nachifur

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值