PRNet将tf转换为pytorch
tf模型可以很方便的转化为pytorch模型吗?可以。
我们需要做哪些步骤呢?
实现方法
我们以PRNet为例,实现tf2torch
- 加载tf模型
from predictor import resfcn256 # import from prnet_dir, maybe using importlib is a better idea
tf_network_def = resfcn256(256, 256)
# tensorflow network forward
net_input = tf.placeholder(
tf.float32, shape=[None, 256, 256, 3])
tf_model = tf_network_def(net_input, is_training=False)
tf_config = tf.ConfigProto(
gpu_options=tf.GPUOptions(allow_growth=False))
tf_config = tf.ConfigProto(
device_count={"GPU":0}
)
sess = tf.Session(config=tf_config)
saver = tf.train.Saver(tf_network_def.vars)
saver.restore(
sess, os.path.join(args.prnet_dir, 'Data', 'net-data', '256_256_resfcn256_weight'))
graph = sess.graph
#print([node.name for node in graph.as_graph_def().node])
- 类似tf的网络结构,定义torch的网络结构,并且模型中添加tf_map用于连接tf节点和torch节点的参数命名的不同。记为PRNet_full(此部分略去)
- 将tf模型中的参数迁移到torch中:
torch_dict = OrderedDict()
for node in graph.as_graph_def().node:
if node.name in torch_model.tf_map:
torch_name = torch_model.tf_map[node.name]
data = graph.get_operation_by_name(node.name).outputs[0]
data_np = sess.run(data)
if len(data_np.shape) > 1:
# weight layouts | tensorflow | pytorch | transpose |
# conv2d_transpose (H, W, out, in) -> (in, out, H, W) (3, 2, 0, 1)
# conv2d (H, W, in, out) -> (out, in, H, W) (3, 2, 0, 1)
torch_dict[torch_name] = torch.tensor(np.transpose(data_np, (3, 2, 0, 1)).astype(np.float32))
else:
torch_dict[torch_name] = torch.tensor(data_np.astype(np.float32))
else:
if node.name.find('save') == -1:
pass
# print('not in {}'.format(node.name))
torch.save(torch_dict, 'from_tf.pth')
- [可选] 将PRNet_full中那些用于连接的tf_map删除,重新建立一个干净的PRNet,load刚刚得到torch权重。
- [可选] 验证对于同样的输入,是否tf和torch得到的结果大致相同。
训练
也可以比较tf和torch训练每一步中是否loss相近。
完结撒花
参考
https://github.com/liguohao96/pytorch-prnet