tensoflow加载pytorch的pt文件参数

@[TOC]tensoflow加载pytorch的pt文件参数

tensoflow加载pytorch的pt文件参数

主要的思路是将pytorch参数改成onnx,然后转成tensorflow的pt文件

import torch
from torch.utils.data import DataLoader
import tensorflow as tf

from model import CNN_Mnist_Relu
from torchvision import transforms, datasets
import onnx
from onnx_tf.backend import prepare
from tensorflow.keras.models import load_model
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train = datasets.MNIST('./data/', train=True, transform=transforms.Compose([transforms.ToTensor(), ]), download=True)
train_loader = DataLoader(train, batch_size=128)
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.to(device)

model_param = torch.load('./e_data_3/fedavg/global_model.pt')
target_model = CNN_Mnist_Relu().to(device)
target_model.load_state_dict(model_param)

torch.onnx.export(target_model, images, './fe.onnx', input_names=['input'], output_names=['output'])



# 加载MNIST数据集

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

model_onnx = onnx.load('fe.onnx')
ir_version = model_onnx.ir_version
model_onnx.ir_version = 6
tf_rep = prepare(model_onnx)
tf_rep.export_graph('output.pt')

'''
这个注释的一部分,还没有测试成功
tf_graph = load_pb('output.pt')
sess = tf.Session(graph=tf_graph)

# 获取输入和输出张量的引用
input_tensor = sess.graph.get_tensor_by_name('input:0')  # 替换为你的输入张量名称
output_tensor = sess.graph.get_tensor_by_name('output:0')  # 替换为你的输出张量名称

# 准备输入数据(使用MNIST测试数据)
batch_size = 1
input_data = images.cpu()  # 假设输入是 28x28 图像

# 在会话中运行模型
output_data = sess.run(output_tensor, feed_dict={input_tensor: input_data})
print(output_data)
# 输出模型的预测结果
predicted_class = tf.argmax(output_data, axis=1).eval(session=sess)
# 打印预测结果
print("Predicted Class:", predicted_class[0])
# 关闭TensorFlow会话
sess.close()
'''

主要的环境搭配
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

JohntyZhou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值