在开源的 TfPyTh 中,不论是 TensorFlow 还是 PyTorch 计算图,它们都可以包装成一个可微函数,并在另一个框架中高效完成前向与反向传播。
github项目地址:BlackHC/TfPyTh。
神经网络交换格式 ONNX,定义了一种通用的计算图。目前 ONNX 已经原生支持 MXNet、PyTorch 和 Caffe2 等大多数框架,但是像 TensorFlow 或 Keras 之类的只能通过第三方转换器转换为 ONNX 格式。TfPyTh 无需改写已有的代码就能在框架间自由转换。TfPyTh 允许我们将 TensorFlow 计算图包装成一个可调用、可微分的简单函数,然后 PyTorch 就能直接调用它完成计算。反过来也是同样的,TensorFlow 也能直接调用转换后的 PyTorch 计算图。
目前 TfPyTh 主要支持三大方法:
torch_from_tensorflow:创建一个 PyTorch 可微函数,并给定 TensorFlow 占位符输入计算张量输出;
eager_tensorflow_from_torch:从 PyTorch 创建一个 Eager TensorFlow 函数;
tensorflow_from_torch:从 PyTorch 创建一个 TensorFlow 运算子或张量。TfPyTh 示例
import tensorflow as tf
import torch as th
import numpy as np
import tfpyth
session = tf.Session()
def get_torch_function():
a = tf.placeholder(tf.float32, name='a')
b = tf.placeholder(tf.float32, name='b')
c = 3 * a + 4 * b * b
f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
return f
f = get_torch_function()
a = th.tensor(1, dtype=th.float32, requires_grad=True)
b = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a, b)assert x == 39.x.backward()assert np.allclose((a.grad, b.grad), (3., 24.))