Caffe 模型转 PyTorch 模型

最近基于 caffe2onnx 做了部分修改，完成了将 Caffe 模型转换为 PyTorch 模型的代码。

 

主代码：需要自己构建 PyTorch 的 Net 架构，同时 Net 中各层的名字要与 Caffe 中对应层的名字一致（Caffe 层名中的“/”会先被替换为“_”再进行匹配）。

    # Load the Caffe network definition and its trained parameters.
    graph, params = LoadCaffeModel(caffe_graph_path, caffe_params_path)

    # Hand-built PyTorch network; its module names must mirror the Caffe
    # layer names so the converter can pair them up.
    net_pytorch = NET()
    net_pytorch.eval()

    print('start convert')
    converter = Caffe2Pytorch(graph, params, net_pytorch)
    converter.convert()
    print("convert finish.")

    # Persist only the converted weights (the state_dict), not the module.
    torch.save(net_pytorch.state_dict(), save_path)
    print("save finish.")

子函数代码

class Caffe2Pytorch:
    """Copy the trained weights of a parsed Caffe model into a PyTorch network.

    The PyTorch network is built by hand.  Each module that should receive
    weights must be named like the matching Caffe layer, with "/" in the
    Caffe name replaced by "_".  Parameters are paired with PyTorch
    state_dict keys that start with "<layer>.".
    """

    def __init__(self, net, model, pytorch_net):
        """
        Args:
            net: parsed network definition (presumably the prototxt message)
                exposing a ``layer`` or ``layers`` field.
            model: parsed trained model (presumably the caffemodel message)
                exposing a ``layer`` or ``layers`` field whose entries carry
                ``name``, ``type`` and ``blobs``.
            pytorch_net: the ``torch.nn.Module`` that receives the weights.
        """
        # Layer lists of the definition and of the weights.  Each is -1 when
        # the message has both or neither of ``layer``/``layers`` populated.
        self.netLayerCaffe = self.GetNetLayerCaffe(net)
        self.netModelCaffe = self.GetNetModelCaffe(model)

        # Model input names and input shapes (kept for compatibility; not
        # populated by this class).
        self.model_input_name = []
        self.model_input_shape = []

        self.pytorch_net = pytorch_net
        # Converted tensors keyed by PyTorch state_dict name.
        self.state_dict = {}

    @staticmethod
    def _select_layers(message, error_message):
        """Return whichever of ``message.layer`` / ``message.layers`` is populated.

        If both or neither are populated, print *error_message* and return -1.
        """
        has_layer = len(message.layer) != 0
        has_layers = len(message.layers) != 0
        if has_layers and not has_layer:
            return message.layers
        if has_layer and not has_layers:
            return message.layer
        print(error_message)
        return -1

    def GetNetLayerCaffe(self, net):
        """Return the layer list of the network definition, or -1 on error."""
        return self._select_layers(net, "prototxt layer error")

    def GetNetModelCaffe(self, model):
        """Return the layer list of the trained model, or -1 on error."""
        return self._select_layers(model, "caffemodel layer error")

    @staticmethod
    def _to_tensor(data, shape, target):
        """Reshape flat blob *data* to *shape* and convert it to a tensor like *target*.

        When the Caffe shape differs from the PyTorch one only by size-1
        dimensions (e.g. a [1, 1, out, in] blob for an [out, in] weight), the
        PyTorch shape is used so ``load_state_dict`` does not reject it.
        The result is cast to *target*'s dtype.
        """
        array = np.array(list(data)).reshape(tuple(shape))
        target_shape = tuple(target.shape)
        if array.shape != target_shape and (
            [d for d in array.shape if d != 1] == [d for d in target_shape if d != 1]
        ):
            array = array.reshape(target_shape)
        return torch.from_numpy(array).to(target.dtype)

    def match(self, caffe_layer_name, pS, pD, suffixes=None):
        """Convert one Caffe layer's blobs and stage them in ``self.state_dict``.

        Args:
            caffe_layer_name: Caffe layer name with "/" already replaced by "_".
            pS: list of blob shapes, one sequence of ints per blob.
            pD: list of flat blob data, parallel to ``pS``.
            suffixes: optional PyTorch parameter names, one per blob (e.g.
                "running_mean").  When given and every "<layer>.<suffix>"
                key exists in the network's state_dict, blobs are assigned by
                name.  Otherwise they are assigned in order to the state_dict
                keys that start with "<layer>.".
        """
        if not pS:
            # A layer without blobs has nothing to copy.
            return

        own_state = self.pytorch_net.state_dict()
        prefix = caffe_layer_name + "."

        targets = None
        if suffixes is not None and len(suffixes) == len(pS):
            named = [prefix + suffix for suffix in suffixes]
            if all(key in own_state for key in named):
                targets = named
        if targets is None:
            targets = [key for key in own_state if key.startswith(prefix)][:len(pS)]

        for name, shape, data in zip(targets, pS, pD):
            print("match success:  caffe name:", caffe_layer_name, " py name:", name)
            self.state_dict[name] = self._to_tensor(data, shape, own_state[name])

    @staticmethod
    def _convert_batchnorm(Params, ParamShape, ParamData):
        """Turn raw Caffe BatchNorm blobs into ``(shapes, data, suffixes)``.

        ``suffixes`` is ``["running_mean", "running_var"]`` when the blobs
        were recognised as mean/var statistics.  Otherwise the blobs are
        returned unchanged with ``suffixes`` set to None.
        """
        mean_var = ["running_mean", "running_var"]
        if len(ParamShape) == 3:
            # Blobs are [mean, var, s]: mean and var must be divided by the
            # moving-average factor s.
            s = Params[-1].data[0]
            mean, var = Params[0].data, Params[1].data
            if s == 0:
                # Caffe uses a scale factor of 0 when s == 0 rather than
                # dividing by zero.
                data = [[0.0 for _ in mean], [0.0 for _ in var]]
            else:
                # The +1e-5 on the variance denominator is kept unchanged
                # from the original conversion.
                data = [[q / s for q in mean], [q / (s + 1e-5) for q in var]]
            return ParamShape[:-1], data, mean_var
        if len(ParamShape) == 2 and len(ParamShape[0]) == 4:
            # [mean, var] stored as 4-D blobs: keep only dimension 1.
            shapes = [[ParamShape[0][1]], [ParamShape[1][1]]]
            data = [
                [q / 1. for q in Params[0].data],
                [q / (1. + 1e-5) for q in Params[1].data],
            ]
            return shapes, data, mean_var
        return ParamShape, ParamData, None

    def convert(self):
        """Copy every Caffe blob into ``self.pytorch_net``.

        For a BatchNorm/BN layer immediately followed by a Scale layer, the
        Scale blobs (gamma[, beta]) are prepended to the BatchNorm statistics
        so one PyTorch BatchNorm module receives all of them.

        Returns:
            ``(ParamShape, ParamData)`` of the last processed Caffe layer.

        Raises:
            ValueError: if the caffemodel layer list could not be determined.
        """
        layers = self.netModelCaffe
        if isinstance(layers, int):
            # GetNetModelCaffe already printed the reason and returned -1.
            raise ValueError("caffemodel layer error")

        ParamShape, ParamData = [], []
        for idx, model_layer in enumerate(layers):
            Params = copy.deepcopy(model_layer.blobs)
            ParamShape = [p.shape.dim for p in Params]
            ParamData = [p.data for p in Params]
            suffixes = None

            if model_layer.type in ("BatchNorm", "BN"):
                ParamShape, ParamData, suffixes = self._convert_batchnorm(
                    Params, ParamShape, ParamData)
                # Bounds-checked: a BatchNorm may be the last layer.
                if idx + 1 < len(layers) and layers[idx + 1].type == "Scale":
                    scale_blobs = copy.deepcopy(layers[idx + 1].blobs)
                    ParamShape = [p.shape.dim for p in scale_blobs] + ParamShape
                    ParamData = [p.data for p in scale_blobs] + ParamData
                    if suffixes is not None:
                        suffixes = ["weight", "bias"][:len(scale_blobs)] + suffixes

            print("caffe param name:", model_layer.name, " param shape :", ParamShape)

            layer_name = model_layer.name.replace("/", "_")
            self.match(layer_name, ParamShape, ParamData, suffixes)

        self.pytorch_net.load_state_dict(self.state_dict, strict=False)

        return ParamShape, ParamData

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

NineDays66

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值