Pytorch转tflite

18 篇文章 3 订阅
2 篇文章 0 订阅

目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。
最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。

经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。

转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。

转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.

下面是一个例子,假设转换的是一个两层的CNN网络。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable


class PytorchNet(nn.Module):
    def __init__(self):
        super(PytorchNet, self).__init__()
        conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.feature = nn.Sequential(conv1, conv2)
        self.init_weights()

    def forward(self, x):
        return self.feature(x)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight.data, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def KerasNet(input_shape=(224, 224, 3)):
    image_input = keras.layers.Input(shape=input_shape)
    # conv1
    network = keras.layers.Conv2D(
        32, (3, 3), strides=(2, 2), padding="valid")(image_input)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=False)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

    # conv2
    network = keras.layers.Conv2D(
        64, (3, 3), strides=(1, 1), padding="valid")(network)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=True)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

    model = keras.Model(inputs=image_input, outputs=network)

    return model


class PytorchToKeras(object):
    def __init__(self, pModel, kModel):
        super(PytorchToKeras, self)
        self.__source_layers = []
        self.__target_layers = []
        self.pModel = pModel
        self.kModel = kModel
        tf.keras.backend.set_learning_phase(0)

    def __retrieve_k_layers(self):
        for i, layer in enumerate(self.kModel.layers):
            if len(layer.weights) > 0:
                self.__target_layers.append(i)

    def __retrieve_p_layers(self, input_size):

        input = torch.randn(input_size)
        input = Variable(input.unsqueeze(0))
        hooks = []

        def add_hooks(module):

            def hook(module, input, output):
                if hasattr(module, "weight"):
                    # print(module)
                    self.__source_layers.append(module)

            if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
                hooks.append(module.register_forward_hook(hook))

        self.pModel.apply(add_hooks)

        self.pModel(input)
        for hook in hooks:
            hook.remove()

    def convert(self, input_size):
        self.__retrieve_k_layers()
        self.__retrieve_p_layers(input_size)

        for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
            print(source_layer)
            weight_size = len(source_layer.weight.data.size())
            transpose_dims = []
            for i in range(weight_size):
                transpose_dims.append(weight_size - i - 1)
            if isinstance(source_layer, nn.Conv2d):
                transpose_dims = [2,3,1,0]
                self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
                ).transpose(transpose_dims), source_layer.bias.data.numpy()])
            elif isinstance(source_layer, nn.BatchNorm2d):
                self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
                                                              source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])

    def save_model(self, output_file):
        self.kModel.save(output_file)

    def save_weights(self, output_file):
        self.kModel.save_weights(output_file, save_format='h5')


pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

#Load the pretrained model
pytorch_model = torch.load('test.pth')

# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")


# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
    "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

Reference:
https://www.jiqizhixin.com/articles/2018-07-18-7

  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
TFLiteTensorFlow Lite)和TensorFlow是Google开发的两个机器学习框架,而PyTorch是Facebook开发的另一个机器学习框架。 TFLiteTensorFlow的移动和嵌入式部署解决方案。它专为在资源受限的设备上进行机器学习推理而设计,如移动设备、嵌入式系统和物联网设备。TFLite提供了一种轻量级的运行时库,可以将TensorFlow模型换为高效的格式,并支持在资源有限的设备上进行实时推理。TFLite还提供了一些优化技术,如模型量化、模型缩减和GPU加速,以提高模型的运行效率和性能。 TensorFlow是一个强大的开源机器学习框架,它提供了一系列丰富的API和工具,用于构建、训练和部署机器学习模型。TensorFlow支持多种机器学习任务,包括图像识别、自然语言处理、推荐系统和时间序列分析等。由于其广泛的支持和社区,TensorFlow成为了许多研究人员和工程师首选的框架。 PyTorch是一个动态计算图机器学习框架,具有直观易用的接口。它与Python的语法非常接近,使得用户可以以一种更自然的方式定义和操作他们的模型。PyTorch还提供了一种称为TorchScript的功能,可将训练好的模型换为一个可以在生产环境中进行推理的格式。此外,PyTorch还具有灵活、高效的GPU加速功能,能够发挥最大的计算性能。 总体而言,TFLite适用于在资源受限的设备上进行机器学习推理,而TensorFlow适用于构建和训练机器学习模型,PyTorch则提供了一种动态计算图的机器学习框架,使得用户可以以一种直观易用的方式定义和操作模型。每个框架都有其独特的优点,选择哪一个取决于具体的使用场景和个人偏好。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值