pytorch模型转到TensorFlow lite:pytorch->onnx->tensorflow->tensorflow lite

本文详细记录了将PyTorch模型转换为TensorFlow Lite的完整过程,包括PyTorch到ONNX,再到TensorFlow,最后转换成TensorFlow Lite。在每个步骤中,作者都提供了关键代码并强调了需要注意的事项,如OPSET版本选择、模型测试等。在TensorFlow到TensorFlow Lite的转换中,特别提到了需要指定支持的操作集来解决不兼容问题。
摘要由CSDN通过智能技术生成

现在很多算法都是用pytorch框架训练的,但是在移动端部署很多又使用TensorFlow lite,因此需要将pytorch模型转换到TensorFlow lite。

将pytorch模型转到TensorFlow lite的流程是pytorch->onnx->tensorflow->tensorflow lite,本文记录一下踩坑的过程。

1、pytorch转onnx

这一步比较简单,使用pytorch自带接口就行。不过有一点需要注意的,就是opset版本,可能会影响后续的转换。

    os.environ['CUDA_VISIBLE_DEVICES']='0'
    model_path = 'model.pth'
    model = architecture.IMDN_RTC(upscale=2).cuda()
    model_dict = utils.load_state_dict(model_path)
    model.load_state_dict(model_dict, strict=True)

    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False

    dummy_input = torch.rand(1, 3, 224, 224).cuda()
    input_names = ["input"]
    #output_names = ["output1", "output2", "output3"]
    output_names = ["output"]
    #使用pytorch的onnx模块来进行转换
    #opset 10转换后,使用onnxruntime运行,在pixelshuffle处会出错
    torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11, verbose=True, 
            input_names=input_names, output_names=output_names, 
            dynamic_axes={'input': [0, 2, 3], 'output': [0, 2, 3]})

    session = onnxruntime.InferenceSession("model.onnx")
    input_name = session.get_inputs()[0].name
    #output_name = session.get_outputs()[0].name
    output_names = [s.name for s in session.get_outputs()]
    input_shape = session.get_inputs()[0].shape

    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = torch.from_numpy(img).unsqueeze(0).float()
    res = session.run(output_names, {input_name: img.cpu().numpy()})
    tmp = res[0]
    tmp = np.clip(tmp[0], 0, 1)
    img = np.array(tmp*255, dtype=np.uint8)
    img = np.transpose(img, (1, 2, 0))[:,:,::-1]
    cv2.imwrite('tmp.jpg', img)

torch.onnx.export后,就得到了onnx模型,后面的代码是使用onnxruntime测试转换后的onnx模型。建议每一步转换后,都测试一下转换后模型的结果,确保每一步都是正确的。

2、onnx转TensorFlow

需要安装onnx-tensorflow进行转换。

from onnx_tf.backend import prepare
import onnx
import tensorflow as tf
if __name__ == '__main__':
    onnx_model = onnx.load("model.onnx")  # load onnx model
    tf_rep = prepare(onnx_model)  # prepare tf representation
    tf_rep.export_graph("model.tf")  # export the model

    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = torch.from_numpy(img).unsqueeze(0).float()
    input = img.numpy()
    if 0:
        output = tf_rep.run(input)  # run the loaded model
        res = output.output[0]
        res = np.clip(res, 0, 1)
        im = np.array(res*255, dtype=np.uint8)
        im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
        cv2.imwrite('tfres.jpg', im1)
    else:
        saved_model = tf.saved_model.load("model.tf")
        detect_fn = saved_model.signatures["serving_default"]
        output = detect_fn(tf.constant(input))
        tmp = np.array(output['output'])[0]
        res = np.clip(tmp, 0, 1)
        im = np.array(res*255, dtype=np.uint8)
        im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
        cv2.imwrite('savedmodelres.jpg', im1)

转换部分就是前三行代码,后面是对TensorFlow模型的测试,确保转换结果没有问题。

3、TensorFlow转TensorFlow lite

没想到这一步是比较坑的,换了几个TensorFlow版本,最终使用tf2.5,转换成功了,参考issue

import tensorflow as tf

if __name__ == '__main__':
    # Convert the model
    converter = tf.lite.TFLiteConverter.from_saved_model('model.tf') # path to the SavedModel directory
    converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
    ]
    tflite_model = converter.convert()

    # Save the model.
    with open('model.tflite', 'wb') as f:
      f.write(tflite_model)

    # test tflite model
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
    #my_signature = interpreter.get_signature_runner()
    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = img[np.newaxis, :]
    #output = my_signature(tf.constant(img))
    print()

    interpreter.resize_tensor_input(0, [1, 3, 256, 256])
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Test the model on random input data.
    #input_shape = input_details[0]['shape_signature']
    #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    input_data = img.astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    res = np.clip(output_data[0], 0, 1)
    im = np.array(res*255, dtype=np.uint8)
    im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
    cv2.imwrite('tfliteres.jpg', im1)

这里需要注意一点,converter.target_spec.supported_ops这个需要加上,不然有些op在TensorFlow lite中不支持,转换不成功。

 

 

 

 

 

 

幻读是指在一个事务中,前后两次相同的查询语句返回了不同的结果行数或数据内容。幻读问题通常出现在并发环境下,当一个事务在读取数据时,另一个事务对数据进行了插入、更新或删除操作,导致前后两次读取的数据不一致。 MySQL 提供了以下几种方式来解决幻读问题: 1. 使用锁机制:通过使用共享锁(S锁)或排他锁(X锁)来保证读取数据和修改数据的互斥。可以使用 SELECT ... FOR UPDATE 语句在读取数据时加上排他锁,或者使用 SELECT ... LOCK IN SHARE MODE 语句在读取数据时加上共享锁。 2. 使用事务隔离级别:MySQL 提供了多个事务隔离级别,包括读未提交(Read Uncommitted)、读已提交(Read Committed)、可重复读(Repeatable Read)和串行化(Serializable)。将事务隔离级别设置为可重复读或串行化可以解决幻读问题,但会增加并发性能开销。 3. 使用间隙锁(Gap Locks):在可重复读或串行化隔离级别下,MySQL 可以使用间隙锁来防止幻读。间隙锁是在索引范围内的空隙中设置的锁,用于防止其他事务在该范围内插入新的数据。 4. 使用 MVCC(Multi-Version Concurrency Control):MVCC 是通过版本号或时间戳来控制事务的并发访问。在可重复读或串行化隔离级别下,MySQL 使用 MVCC 来为每个事务提供一个独立的数据快照,避免了幻读问题。 需要根据具体的业务需求和并发情况选择适合的解决方案。同时,还需要注意合理设计数据库索引、优化查询语句等措施,以提高数据库的并发性能和减少幻读的发生。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值