Pure Practice: Importing Models Trained in Other Frameworks into PyTorch (Part 2)

The previous post briefly introduced the core method for importing a TensorFlow pretrained model into PyTorch, but it was fairly general, and readers unfamiliar with the frameworks may still have been left confused. This post therefore works through a concrete example: FALSR, the super-resolution network obtained by Xiaomi's NAS search. A screenshot of the open-source project is shown below; it provides only the pb files and a test script, with no description of the model at all. How do we import this kind of open-source model?

(Screenshot of the FALSR open-source repository)

First, we need to determine the FALSR network architecture and reimplement it in PyTorch. How do we determine that architecture? We can inspect the graph structure with TensorBoard (code for exporting the graph is provided in the appendix), follow the data-processing flow through the graph to work out the architecture, and finally implement it in PyTorch. The FALSR-A architecture can be written in PyTorch as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Cell(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Cell, self).__init__()
        self.conv0 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                nn.ReLU())

        self.net = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                nn.ReLU())

    def forward(self, x):
        x = self.conv0(x)

        return x + self.net(x)

# converted from FALSR-A (xiaomi ailab)
class FALSRA(nn.Module):
    def __init__(self):
        super(FALSRA, self).__init__()
        self.conv0 = nn.Conv2d(1, 32, 3, 1, 1)
        self.cell0 = Cell(32, 64)
        self.cell1 = nn.Sequential(nn.Conv2d(64, 48, 1, 1),
                                   nn.ReLU())
        self.cell2 = Cell(48, 64)
        self.cell3 = Cell(112, 64)
        self.cell4 = Cell(144, 64)
        self.cell5 = nn.Sequential(nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, 1, 1),
                                   nn.ReLU())
        self.cell6 = Cell(208, 64)
        self.conv1 = nn.Sequential(nn.Conv2d(464, 32, 3, 1, 1), nn.ReLU())
        self.conv2 = nn.Conv2d(8, 1, 3, 1, 1)

        # just used for ypbpr2rgb
        self.conv3 = nn.Conv2d(3, 3, 1, 1)

    def forward(self, x, pbpr):
        conv0 = F.relu(self.conv0(x))
        cell0 = self.cell0(conv0)
        cell1 = self.cell1(cell0)
        cell2 = self.cell2(cell1)
        cell3 = self.cell3(torch.cat([cell1, cell2], dim=1))
        cell4 = self.cell4(torch.cat([conv0, cell1, cell3], dim=1))
        cell5 = self.cell5(cell4)
        cell6 = self.cell6(torch.cat([conv0, cell1, cell4, cell5], dim=1))
        conv1 = self.conv1(torch.cat([conv0, cell0, cell1, cell2, cell3, cell4, cell5, cell6], dim=1))

        out   = conv0 + conv1
        # channel reorder before pixel_shuffle: tf.nn.depth_to_space and
        # F.pixel_shuffle expect different channel layouts (see the note below)
        out   = out[:, [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
                        4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31], :, :]
        out   = F.pixel_shuffle(out, 2)
        out   = self.conv2(out) * 255

        pbpr  = (pbpr + 0.5) * 255
        out   = torch.cat([out, pbpr], dim=1)
        out   = self.conv3(out)
        out   = torch.clamp(out, 0, 255)

        return out

Note: some readers may wonder about the channel reordering right before pixel_shuffle in forward(). Why is it necessary? To be honest, I had never paid attention to this either and always assumed that TensorFlow's depth_to_space and PyTorch's pixel_shuffle were exactly equivalent. Only during model verification did the subtle difference in how the two ops arrange channels show up. This upsampling module is one of the most basic building blocks in super-resolution, so be careful with it when converting models!
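
To make the difference concrete, the following is a minimal sketch (not part of the original conversion scripts) that emulates TensorFlow's depth_to_space channel layout on an NCHW tensor and shows that the channel permutation used in forward() lets F.pixel_shuffle reproduce it, assuming a scale factor of 2 and 8 output channels as in FALSR-A:

import torch
import torch.nn.functional as F

r, c_out = 2, 8                                   # FALSR-A: x2 upscaling, 8 channels after the shuffle
x = torch.randn(1, r * r * c_out, 5, 5)

def depth_to_space_tf(x, r):
    # emulate tf.nn.depth_to_space: input channel index = (i*r + j)*c_out + k
    n, c, h, w = x.shape
    c_out = c // (r * r)
    x = x.view(n, r, r, c_out, h, w)              # split channels into (i, j, k)
    x = x.permute(0, 3, 4, 1, 5, 2)               # -> n, k, h, i, w, j
    return x.reshape(n, c_out, h * r, w * r)

# F.pixel_shuffle instead reads channel index = k*r*r + (i*r + j), so a
# permutation is needed; this reproduces [0, 8, 16, 24, 1, 9, 17, 25, ...]
perm = [(p % (r * r)) * c_out + p // (r * r) for p in range(r * r * c_out)]

print(torch.allclose(depth_to_space_tf(x, r), F.pixel_shuffle(x[:, perm], r)))   # True

In short, TensorFlow reads the channels in (block, output-channel) order while PyTorch reads them in (output-channel, block) order; without the permutation the upscaled result comes out with its sub-pixels scrambled.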

Next, with the FALSR-A model rewritten in PyTorch, the following step is to export the useful weights from the pb file; they can be saved as a pkl, an npz, or any other intermediate format. Reference code for extracting and saving the weights:

import tensorflow as tf
import numpy as np

def tr(v):
    # TensorFlow conv kernels are stored as HWIO (height, width, in, out);
    # PyTorch expects OIHW, hence the (3, 2, 0, 1) transpose.
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        # fully connected weights: TF uses (in, out), PyTorch uses (out, in)
        return np.ascontiguousarray(v.transpose())
    return v

pb_path = './pretrained_model/FALSR-A.pb'

def create_graph(modelpath):
    with tf.gfile.FastGFile(modelpath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

create_graph(pb_path)

constant_values = dict()

with tf.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        print(constant_op)
        constant_values[constant_op.name] = tr(sess.run(constant_op.outputs[0]))



out = constant_values
np.savez('falsr-a.npz', out=out)
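
To see which keys actually end up in the archive (these names are used verbatim when assigning parameters in the next step), the saved file can be inspected with a small snippet like the one below; allow_pickle=True is needed on newer NumPy versions because the dict is stored as an object array:

import numpy as np

data = np.load('falsr-a.npz', allow_pickle=True)
out = data['out'][()]                    # recover the {op_name: ndarray} dict
for name, value in sorted(out.items()):
    print(name, value.shape)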

Then, with the FALSR-A parameters exported, the remaining work is to load them into the PyTorch model. Without going into the details again, here is the code:

import numpy as np
import torch

from model_A import FALSRA   # the PyTorch definition of FALSR-A shown above

def load_conv_parameter(model, index, data, key):
    # helper: assign kernel/bias of one conv layer inside an nn.Sequential (.net)
    model.net[index].weight = torch.nn.Parameter(torch.from_numpy(data[key + '/kernel']))
    if hasattr(model.net[index], 'bias') and model.net[index].bias is not None:
        model.net[index].bias = torch.nn.Parameter(torch.from_numpy(data[key + '/bias']))

    return model

# allow_pickle=True is required on newer NumPy versions because the dict was
# saved as an object array
data = np.load('falsr-a.npz', allow_pickle=True)
out = data['out'][()]

## load
model = FALSRA()

model.conv0.weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/n32s1/c/kernel']))
model.conv0.bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/n32s1/c/bias']))

model.conv1[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/n32s1/2/0/kernel']))
model.conv1[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/n32s1/2/0/bias']))

model.conv2.weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/yout/kernel']))
model.conv2.bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/yout/bias']))

model.conv3.weight          = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/Const']))
model.conv3.bias            = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/Const_1']))

model.cell0.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b0/kernel']))
model.cell0.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b0/bias']))
model.cell0.net[0].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b1/kernel']))
model.cell0.net[0].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b1/bias']))
model.cell0.net[2].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b2/kernel']))
model.cell0.net[2].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b2/bias']))
model.cell0.net[4].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b3/kernel']))
model.cell0.net[4].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell0/conv_f64_k3_b4_isskip_b3/bias']))

model.cell1[0].weight       = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell1/conv_f48_k1_b1_isskip_b0/kernel']))
model.cell1[0].bias         = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell1/conv_f48_k1_b1_isskip_b0/bias']))

model.cell2.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b0/kernel']))
model.cell2.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b0/bias']))
model.cell2.net[0].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b1/kernel']))
model.cell2.net[0].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b1/bias']))
model.cell2.net[2].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b2/kernel']))
model.cell2.net[2].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b2/bias']))
model.cell2.net[4].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b3/kernel']))
model.cell2.net[4].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell2/conv_f64_k3_b4_isskip_b3/bias']))

model.cell3.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b0/kernel']))
model.cell3.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b0/bias']))
model.cell3.net[0].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b1/kernel']))
model.cell3.net[0].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b1/bias']))
model.cell3.net[2].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b2/kernel']))
model.cell3.net[2].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b2/bias']))
model.cell3.net[4].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b3/kernel']))
model.cell3.net[4].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell3/conv_f64_k3_b4_isskip_b3/bias']))

model.cell4.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b0/kernel']))
model.cell4.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b0/bias']))
model.cell4.net[0].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b1/kernel']))
model.cell4.net[0].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b1/bias']))
model.cell4.net[2].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b2/kernel']))
model.cell4.net[2].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b2/bias']))
model.cell4.net[4].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b3/kernel']))
model.cell4.net[4].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell4/conv_f64_k3_b4_isskip_b3/bias']))

model.cell5[0].weight       = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b0/kernel']))
model.cell5[0].bias         = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b0/bias']))
model.cell5[2].weight       = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b1/kernel']))
model.cell5[2].bias         = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b1/bias']))
model.cell5[4].weight       = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b2/kernel']))
model.cell5[4].bias         = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b2/bias']))
model.cell5[6].weight       = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b3/kernel']))
model.cell5[6].bias         = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell5/conv_f64_k1_b4_noskip_b3/bias']))

model.cell6.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b0/kernel']))
model.cell6.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b0/bias']))
model.cell6.net[0].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b1/kernel']))
model.cell6.net[0].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b1/bias']))
model.cell6.net[2].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b2/kernel']))
model.cell6.net[2].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b2/bias']))
model.cell6.net[4].weight   = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b3/kernel']))
model.cell6.net[4].bias     = torch.nn.Parameter(torch.from_numpy(out['test_sr_evaluator_i1_b0_g/cell6/conv_f64_k3_b4_isskip_b3/bias']))

torch.save(model.state_dict(), './FALSR-A-torch.pth.tar')
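
Incidentally, the per-cell assignments above follow a completely regular key pattern (cellN/conv_f64_k3_b4_isskip_bM), so the five residual cells could equivalently be filled in a loop that reuses the load_conv_parameter helper defined earlier. A minimal sketch of that alternative:

prefix = 'test_sr_evaluator_i1_b0_g'
for idx in (0, 2, 3, 4, 6):                        # the cells built from the Cell module
    cell = getattr(model, 'cell%d' % idx)
    base = '%s/cell%d/conv_f64_k3_b4_isskip' % (prefix, idx)
    # the 3x3 conv in front of the residual branch
    cell.conv0[0].weight = torch.nn.Parameter(torch.from_numpy(out[base + '_b0/kernel']))
    cell.conv0[0].bias   = torch.nn.Parameter(torch.from_numpy(out[base + '_b0/bias']))
    # the three convs inside cell.net sit at indices 0, 2 and 4 (ReLUs in between)
    for b, layer in zip((1, 2, 3), (0, 2, 4)):
        load_conv_parameter(cell, layer, out, '%s_b%d' % (base, b))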

Finally, with the TF-to-PyTorch conversion complete, the most important remaining step is verification. Without further ado, the code:

import cv2
import numpy as np
import scipy.misc              # note: imread/imresize were removed from newer SciPy releases
import torch

from model_A import FALSRA
# sc below provides rgb2ypbpr; this color-conversion helper comes from the
# FALSR repository's utility code, so point the import at your local copy.

def main():
    lr_path = './dataset/Set5/img_001_SRF_2_LR.png'
    calculate_lr_img = scipy.misc.imread(lr_path, mode='RGB')
    # network input: the Y channel of the low-resolution image, shape (1, 1, H, W)
    ypbprt = sc.rgb2ypbpr(calculate_lr_img / 255.0)[..., 0][np.newaxis, ...][np.newaxis]

    scale = 2
    size = calculate_lr_img.shape
    x_scale = scipy.misc.imresize(calculate_lr_img, [size[0] * scale, size[1] * scale], interp='bicubic', mode=None)
    pbpr = sc.rgb2ypbpr(x_scale / 255)[..., 1:]
    pbpr = np.transpose(pbpr, (2, 0,1))[np.newaxis, ...]
    pbpr = torch.from_numpy(pbpr.astype(np.float32))

    model = FALSRA()
    model.load_state_dict(torch.load('FALSR-A-torch.pth.tar'))
    model.eval()

    inputs = torch.from_numpy(ypbprt.astype(np.float32))
    with torch.no_grad():
        output = model(inputs, pbpr)

    pred = output[0].detach().numpy().astype(np.uint8)
    pred = np.transpose(pred, (1,2,0))
    cv2.imshow('pred', pred[:,:,::-1])
    cv2.waitKey()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

OK, that wraps up this running-log style write-up. There is honestly not much more to explain; reading the code should make everything clear. All the key points of the model conversion have been presented. That's it, and good luck.

Appendix

Here is how to export the graph from a pb file so it can be inspected with TensorBoard. Again, straight to the code:

import tensorflow as tf

model = 'FALSR-A.pb'
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)

After running the above code, a graph event file appears in the log directory. Note: the file name will differ between runs.

(Screenshot of the generated event file in the log directory)

Finally, run tensorboard --logdir=./log to view the FALSR-A graph, shown below. Once the graph structure of FALSR-A is known, all that remains is to rewrite the network in PyTorch from that graph. Good luck.

(TensorBoard visualization of the FALSR-A graph)
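
As a complement to the TensorBoard view, the op names and output shapes can also be printed directly from the imported graph, which makes it easy to double-check the channel counts used in the PyTorch reimplementation (for example the 112/144/208/464 concatenations). A minimal sketch using the TF 1.x API:

import tensorflow as tf

with tf.gfile.FastGFile('FALSR-A.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

for op in tf.get_default_graph().get_operations():
    # keep only the compute-relevant ops so the listing stays readable
    if op.type in ('Conv2D', 'Relu', 'ConcatV2', 'DepthToSpace', 'Add'):
        print(op.type.ljust(12), op.name, [str(t.shape) for t in op.outputs])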
