torch程序转tf指南

1 网络迁移

基于torch实现的网络需要修改为基于tf的实现, 类似于翻译的工作, 主要是把torch的算子替换为tf等价的替换。
经常会遇到torch中的算子tf中没有, 或者虽然有但功能不等价的情况, 尤其是后一种情况需要格外注意。
还有一个主要区别是torch中通常定义一个类, 在__init__中定义好需要用到的算子,在forward中进行网络的连接;
但是tf通常是函数式编程, 虽然也可以定义类, 但是通常没必要, 因为tf中很多算子都是函数, 定义和计算是一起完成的。
比如在torch中:

class Net():
    def __init__(self):
        self.conv = torch.nn.Conv2d(...)
    
    def forward(self, input):
        out = self.conv(input)
        return out

__init__先定义一个conv操作, 然后在forward中使用。

但是在tf中:

    out = tf.layers.conv2d(input, ...)

定义和计算是在一起的, 因为tf.layers.conv2d本身只是个函数。

2 数据部分

数据需要用tf的方式进行读取。

3 运行部分

tf是静态图模式, 需要先定义计算图, 然后调用sess.run运行计算图才能得到最终结果。
因为是静态图模式, 要想进行调试输出中间结果是比较困难的。 这里提供了一些小技巧,
可以参考:tensorflow调试小技巧

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值