前面已经简单介绍了将Tensorflow预训练模型导入Pytorch框架的核心方法,但写得比较笼统,估计对框架不熟悉的朋友可能还是云里雾里。鉴于此,本文以小米通过NAS搜索得到的超分网络FALSR为例进行介绍。该开源项目截屏如下所示:它只提供了pb文件以及测试脚本,没有任何关于模型结构的介绍。对于这类开源模型,我们该如何导入呢?
![811cae612810083a81380607ba8909f7.png](https://i-blog.csdnimg.cn/blog_migrate/ef3d7dffe9e44ebb5eae287c025e22d7.jpeg)
首先,我们要做的第一步是确认FALSR的网络架构并用Pytorch实现。那么如何确认FALSR的网络架构呢?我们可以通过tensorboard查看它的图结构(附录中提供了保存该图的代码),根据图中的数据处理流程确认网络架构,最终用Pytorch完成网络架构的实现。FALSR-A的网络架构用Pytorch可以描述如下:
class Cell(nn.Module):
    """Residual cell: a 3x3 conv+ReLU projection followed by a residual branch.

    ``conv0`` projects ``in_channels`` to ``out_channels``.  ``net`` holds three
    3x3 conv+ReLU pairs, so its conv layers sit at Sequential indices 0, 2 and 4
    (the weight loader addresses them by those indices).  The branch output is
    added back onto the projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
        )
        layers = []
        for _ in range(3):
            layers += [nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.conv0(x)
        return projected + self.net(projected)
# converted from FALSR-A (xiaomi ailab)
class FALSRA(nn.Module):
    """PyTorch re-implementation of FALSR-A (x2 super-resolution).

    Submodule names and Sequential indices mirror how the TF weights are
    imported, so do not rename or reorder layers.
    """

    def __init__(self):
        super(FALSRA, self).__init__()
        self.conv0 = nn.Conv2d(1, 32, 3, 1, 1)
        self.cell0 = Cell(32, 64)
        self.cell1 = nn.Sequential(nn.Conv2d(64, 48, 1, 1),
                                   nn.ReLU())
        self.cell2 = Cell(48, 64)
        self.cell3 = Cell(112, 64)   # cat(cell1: 48, cell2: 64)
        self.cell4 = Cell(144, 64)   # cat(conv0: 32, cell1: 48, cell3: 64)
        self.cell5 = nn.Sequential(nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU())
        self.cell6 = Cell(208, 64)   # cat(conv0, cell1, cell4, cell5)
        # 464 = concatenation of conv0 and all seven cell outputs
        self.conv1 = nn.Sequential(nn.Conv2d(464, 32, 3, 1, 1), nn.ReLU())
        self.conv2 = nn.Conv2d(8, 1, 3, 1, 1)
        # just used for ypbpr2rgb: a 1x1 conv acting as a 3x3 colour matrix
        self.conv3 = nn.Conv2d(3, 3, 1, 1)

    @staticmethod
    def _depth_to_space_channel_order(x, block_size):
        """Reorder channels so ``F.pixel_shuffle`` reproduces TF ``depth_to_space``.

        With ``C`` output channels and sub-pixel offset ``k = i * r + j``
        (``r = block_size``), TF's depth_to_space reads input channel
        ``k * C + c`` while PyTorch's pixel_shuffle reads ``c * r * r + k``.
        This returns ``x`` with its channels permuted from the former order to
        the latter.  For 32 channels and ``r = 2`` it is the permutation
        [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31].

        Args:
            x: NCHW tensor whose channel count is divisible by ``block_size ** 2``.
            block_size: the upscaling factor ``r``.

        Returns:
            A tensor of the same shape as ``x``.
        """
        n, channels, h, w = x.shape
        groups = block_size * block_size
        return (x.reshape(n, groups, channels // groups, h, w)
                 .transpose(1, 2)
                 .reshape(n, channels, h, w))

    def forward(self, x, pbpr):
        """Super-resolve the luma channel and merge it with upscaled chroma.

        Args:
            x: low-resolution luma, shape (N, 1, H, W).
            pbpr: chroma already upscaled to the output size, shape
                (N, 2, 2H, 2W); presumably in [-0.5, 0.5], since it is mapped
                to pixel range with ``(pbpr + 0.5) * 255``.

        Returns:
            Tensor of shape (N, 3, 2H, 2W) clamped to [0, 255].
        """
        conv0 = F.relu(self.conv0(x))
        cell0 = self.cell0(conv0)
        cell1 = self.cell1(cell0)
        cell2 = self.cell2(cell1)
        cell3 = self.cell3(torch.cat([cell1, cell2], dim=1))
        cell4 = self.cell4(torch.cat([conv0, cell1, cell3], dim=1))
        cell5 = self.cell5(cell4)
        cell6 = self.cell6(torch.cat([conv0, cell1, cell4, cell5], dim=1))
        conv1 = self.conv1(torch.cat(
            [conv0, cell0, cell1, cell2, cell3, cell4, cell5, cell6], dim=1))
        out = conv0 + conv1
        # TF depth_to_space and PyTorch pixel_shuffle take channels in different
        # orders; permute first so the result matches the TF graph exactly.
        out = self._depth_to_space_channel_order(out, 2)
        out = F.pixel_shuffle(out, 2)
        out = self.conv2(out) * 255
        pbpr = (pbpr + 0.5) * 255
        out = torch.cat([out, pbpr], dim=1)
        out = self.conv3(out)
        out = torch.clamp(out, 0, 255)
        return out
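注:也许有不少朋友会注意到forward中在调用pixel_shuffle之前对通道重新排序的那一步,为什么要这样处理?事实上,这个问题我以前也没注意到,一直以为tensorflow中的depth_to_space与pytorch中的pixel_shuffle两个OP是完全一致的,直到实际验证模型时才发现两者的处理方式存在细微差异。此处差异可自己揣摩一下(提示:设放大倍数为r、输出通道数为C,对输出通道c、子像素位置(i, j),depth_to_space取的是输入通道(i*r+j)*C+c,而pixel_shuffle取的是输入通道c*r*r+i*r+j)。这个模块是超分中最基本的模块之一,大家转换模型时一定要注意!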
其次,我们已经完成了FALSR-A模型在pytorch框架下的重写,下一步的工作则是将pb中有用的权重导出来,可以保存为pkl,也可以保存为npz等不同格式的中间文件。权值提取保存的代码参考如下:
import tensorflow as tf
import numpy as np
def tr(v):
    """Convert a TF weight array to PyTorch layout.

    4-D arrays are permuted with axes (3, 2, 0, 1), i.e. a TF conv kernel laid
    out as (H, W, in, out) becomes PyTorch's (out, in, H, W).  2-D arrays are
    transposed.  Both converted results are made C-contiguous.  Arrays of any
    other rank (including scalars) are returned unchanged, as the same object.
    """
    axes_by_rank = {4: (3, 2, 0, 1), 2: (1, 0)}
    axes = axes_by_rank.get(v.ndim)
    if axes is None:
        return v
    return np.ascontiguousarray(v.transpose(axes))
pb_path = './pretrained_model/FALSR-A.pb'


def create_graph(modelpath):
    """Import the frozen GraphDef stored at *modelpath* into the default graph.

    ``name=''`` keeps the original op names, with no extra name-scope prefix.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(modelpath, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name='')


create_graph(pb_path)

# Evaluate every Const op in the imported graph and store its value, converted
# to PyTorch layout by tr(), under the op's name.
constant_values = {}
with tf.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for const_op in constant_ops:
        print(const_op)
        tensor_value = sess.run(const_op.outputs[0])
        constant_values[const_op.name] = tr(tensor_value)
out = constant_values
# The whole dict is saved as one pickled object array under key 'out', so
# reading it back with np.load requires allow_pickle=True.
np.savez('false-a.npz', out=out)
然后,既然已经完成了FALSR-A模型的参数导出,接下来的工作就是将其导入到Pytorch框架中。细节不再介绍,直接上代码(注意:np.load读取的文件名需与上一步np.savez保存的文件名一致;另外由于保存的是一个字典,新版numpy读取时需要指定allow_pickle=True):
import torch
from model_A import FALSRA
def load_conv_parameter(model, index, data, key):
    """Install conv weights from *data* into ``model.net[index]``.

    The kernel is read from ``data[key + '/kernel']``.  The bias is read from
    ``data[key + '/bias']`` only when the target layer actually has a bias.
    The arrays are presumably already in PyTorch layout (e.g. converted with
    ``tr``); no transposition happens here.

    Returns:
        *model*, to allow chaining.
    """
    layer = model.net[index]
    layer.weight = torch.nn.Parameter(torch.from_numpy(data[key + '/kernel']))
    if getattr(layer, 'bias', None) is not None:
        layer.bias = torch.nn.Parameter(torch.from_numpy(data[key + '/bias']))
    return model
import numpy as np
# The extraction step writes 'false-a.npz' via np.savez with a dict stored as
# a single pickled object array under key 'out'.  NumPy >= 1.16.3 refuses to
# unpickle unless allow_pickle=True.  Only load files you produced yourself:
# unpickling can execute arbitrary code.
with np.load('false-a.npz', allow_pickle=True) as data:
    out = data['out'][()]  # unwrap the 0-d object array back into the dict


def falsra_tf_key_map(prefix='test_sr_evaluator_i1_b0_g/'):
    """Map every FALSRA ``state_dict`` key to the TF constant holding its value.

    Args:
        prefix: name scope of the FALSR-A weights inside the frozen graph.

    Returns:
        dict such as ``{'cell0.net.2.weight':
        prefix + 'cell0/conv_f64_k3_b4_isskip_b2/kernel', ...}`` covering all
        58 parameters of FALSRA.
    """
    mapping = {}

    def conv(module, scope):
        # Each TF conv is stored as '<scope>/kernel' and '<scope>/bias'.
        mapping[module + '.weight'] = prefix + scope + '/kernel'
        mapping[module + '.bias'] = prefix + scope + '/bias'

    conv('conv0', 'n32s1/c')
    conv('conv1.0', 'n32s1/2/0')
    conv('conv2', 'yout')
    # The YPbPr->RGB 1x1 conv comes from graph constants named Const / Const_1.
    mapping['conv3.weight'] = prefix + 'Const'
    mapping['conv3.bias'] = prefix + 'Const_1'

    # Residual cells: TF sub-layer b0 is Cell.conv0[0], and b1..b3 are
    # Cell.net[0], net[2] and net[4].  The odd indices are parameter-free ReLUs.
    for cell in ('cell0', 'cell2', 'cell3', 'cell4', 'cell6'):
        scope = cell + '/conv_f64_k3_b4_isskip_b'
        conv(cell + '.conv0.0', scope + '0')
        for tf_idx in (1, 2, 3):
            conv('%s.net.%d' % (cell, 2 * (tf_idx - 1)), scope + str(tf_idx))

    conv('cell1.0', 'cell1/conv_f48_k1_b1_isskip_b0')
    for tf_idx in range(4):
        conv('cell5.%d' % (2 * tf_idx), 'cell5/conv_f64_k1_b4_noskip_b%d' % tf_idx)
    return mapping


## load
model = FALSRA()
state_dict = {name: torch.from_numpy(out[tf_key])
              for name, tf_key in falsra_tf_key_map().items()}
# Strict loading raises on any missing/unexpected key or shape mismatch.  A bad
# mapping or layout therefore fails here, instead of producing a checkpoint
# that cannot be reloaded later.
model.load_state_dict(state_dict)
torch.save(model.state_dict(), './FALSR-A-torch.pth.tar')
最后,TF模型转Pytorch的工作已经完成,现在最重要的就是验证了。话不多说,直接上代码:
def main():
    """Run the converted FALSR-A on one Set5 image and display the x2 result.

    The luma of the low-resolution image is fed to the network.  The chroma is
    upscaled bicubically and passed alongside it.  The network's RGB output is
    shown in an OpenCV window.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    lr_path = './dataset/Set5/img_001_SRF_2_LR.png'
    scale = 2

    # scipy.misc.imread / imresize no longer exist in current SciPy releases,
    # so use OpenCV, which this function already needs for display.
    bgr = cv2.imread(lr_path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError('cannot read image: ' + lr_path)
    calculate_lr_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Network input: luma (channel 0 of YPbPr), shaped (1, 1, H, W).
    ypbprt = sc.rgb2ypbpr(calculate_lr_img / 255.0)[..., 0][np.newaxis, ...][np.newaxis]
    inputs = torch.from_numpy(ypbprt.astype(np.float32))

    # Chroma: bicubically upscaled to the output size, shaped (1, 2, sH, sW).
    # NOTE: OpenCV's bicubic kernel may differ slightly from PIL's (which
    # scipy.misc.imresize used), so chroma can differ marginally from the
    # original TF test pipeline.
    height, width = calculate_lr_img.shape[:2]
    x_scale = cv2.resize(calculate_lr_img, (width * scale, height * scale),
                         interpolation=cv2.INTER_CUBIC)
    pbpr = sc.rgb2ypbpr(x_scale / 255)[..., 1:]
    pbpr = np.transpose(pbpr, (2, 0, 1))[np.newaxis, ...]
    pbpr = torch.from_numpy(pbpr.astype(np.float32))

    model = FALSRA()
    model.load_state_dict(torch.load('FALSR-A-torch.pth.tar', map_location='cpu'))
    model.eval()
    with torch.no_grad():
        output = model(inputs, pbpr)

    # The model clamps its output to [0, 255]; round instead of truncating.
    pred = np.round(output[0].numpy()).astype(np.uint8)
    pred = np.transpose(pred, (1, 2, 0))
    cv2.imshow('pred', pred[:, :, ::-1])  # RGB -> BGR for OpenCV
    cv2.waitKey()
    cv2.destroyAllWindows()
OK,这篇流水账式的记录终于完成了。事实上也没什么需要详细介绍的,看代码就完全懂了。模型转换的关键内容已经全部呈现,到此结束,祝君好运。
附录
这里附上如何将pb文件中的图导出来并通过tensorboard查看。同样直接上代码:
import tensorflow as tf
# Dump the graph of a frozen .pb model so it can be browsed in TensorBoard.
model = 'FALSR-A.pb'
graph = tf.get_default_graph()
# ParseFromString replaces the message contents, so start from an empty
# GraphDef.  Reading inside `with` closes the file handle afterwards.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
# Flush pending events to disk and release the event file.
summaryWriter.close()
运行上述代码后,可以在log文件夹中看到一个保存了图结构的事件文件。注:不同人运行后得到的文件名会不一致。
![9e8cc22699af29dad3979da4bdc3a891.png](https://i-blog.csdnimg.cn/blog_migrate/1586aaccc023abf9dbe29af2fa14a7cf.png)
最后执行 `tensorboard --logdir=./log` 命令,即可查看FALSR-A的图,附图如下。知道了FALSR-A的图结构后,剩下的就是根据该图在Pytorch框架下重写网络架构了。祝君好运。
![2f2c9289043eb7a0fa1218708958aa5b.png](https://i-blog.csdnimg.cn/blog_migrate/1421042123423f1d08d78a1be6e1c64d.jpeg)