从 SRGAN(TensorFlow) 导出生成网络(generator)训练参数数据

按《tensorflow2caffe(2) : 如何在tensorflow中取出模型参数》一文的代码原理:
把下面的代码放到 main.py 的 generator 部分:

        #-------------------------------------------------------------

        # 这里应该是global_variables,如果trainable_variables,则会缺少一些参数
        # all_vars = tf.trainable_variables()

        all_vars = tf.global_variables()
        fp = open('SRGAN_generator_model.txt', 'w')
        for v in all_vars:

            name = v.name

            fname = name + '.prototxt'

            fname = fname.replace('/','_')

            print (fname)
            fp.write(fname)
            fp.write('\n')

            v_4d = np.array(sess.run(v))
            if v_4d.ndim == 4:

                #v_4d.shape [ H, W, I, O ]        

                v_4d = np.swapaxes(v_4d, 0, 2) # swap H, I

                v_4d = np.swapaxes(v_4d, 1, 3) # swap W, O

                v_4d = np.swapaxes(v_4d, 0, 1) # swap I, O

                #v_4d.shape [ O, I, H, W ]


                vshape = v_4d.shape[:]

                v_1d = v_4d.reshape(v_4d.shape[0]*v_4d.shape[1]*v_4d.shape[2]*v_4d.shape[3])

                fp.write('  blobs {\n')

                for vv in v_1d:

                    fp.write('    data: %8f' % vv)

                    fp.write('\n')

                fp.write('    shape {\n')

                for s in vshape:

                    fp.write('      dim: ' + str(s))#print dims

                    fp.write('\n')

                fp.write('    }\n')

                fp.write('  }\n')
            elif v_4d.ndim == 1 :#do not swap


                fp.write('  blobs {\n')

                for vv in v_4d:

                    fp.write('    data: %.8f' % vv)

                    fp.write('\n')

                fp.write('    shape {\n')

                fp.write('      dim: ' + str(v_4d.shape[0]))#print dims

                fp.write('\n')

                fp.write('    }\n')

                fp.write('  }\n')

        fp.close()
        #-------------------------------------------------------------

然后运行就导出了一个文本方式的数据

SRGAN_generator_model.txt:

generator_generator_unit_input_stage_conv_Conv_weights:0.prototxt
  blobs {
    data: -0.022789
    data: -0.008191
    data: -0.001650
    ...省略
    data: 0.007882
    data: 0.007484
    shape {
      dim: 64
      dim: 3
      dim: 9
      dim: 9
    }
  }
generator_generator_unit_input_stage_conv_Conv_biases:0.prototxt
  blobs {
    data: -0.09940426
    ...省略
    data: -0.06667865
    shape {
      dim: 64
    }
  }
generator_generator_unit_input_stage_Prelu_alpha:0.prototxt
  blobs {
    ...省略
    shape {
      dim: 64
    }
  }
generator_generator_unit_resblock_1_conv_1_Conv_weights:0.prototxt
  blobs {
    ...省略
    shape {
      dim: 64
      dim: 64
      dim: 3
      dim: 3
    }
  }
generator_generator_unit_resblock_1_BatchNorm_beta:0.prototxt
...省略

这样就可以和 caffe_srgan-master中的数据对比一下异同。

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值