1 网络迁移
基于torch实现的网络需要修改为基于tf的实现, 类似于翻译的工作, 主要是把torch的算子替换为tf等价的替换。
经常会遇到torch中的算子tf中没有, 或者虽然有但功能不等价的情况, 尤其是后一种情况需要格外注意。
还有一个主要区别是torch中通常定义一个类, 在__init__
中定义好需要用到的算子,在forward
中进行网络的连接;
但是tf通常是函数式编程, 虽然也可以定义类, 但是通常没必要, 因为tf中很多算子都是函数, 定义和计算是一起完成的。
比如在torch中:
class Net():
def __init__(self):
self.conv = torch.nn.Conv2d(...)
def forward(self, input):
out = self.conv(input)
return out
在__init__
先定义一个conv操作, 然后在forward
中使用。
但是在tf中:
out = tf.layers.conv2d(input, ...)
定义和计算是在一起的, 因为tf.layers.conv2d
本身只是个函数。
2 数据部分
数据需要用tf的方式进行读取。
3 运行部分
tf是静态图模式, 需要先定义计算图, 然后调用sess.run
运行计算图才能得到最终结果。
因为是静态图模式, 要想进行调试输出中间结果是比较困难的。 这里提供了一些小技巧,
可以参考:tensorflow调试小技巧