Loading PyTorch .pt parameter files in TensorFlow
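The main idea is to load the PyTorch parameters into the model, export the model to ONNX, and then convert the ONNX model into a TensorFlow graph file.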
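```python
import numpy as np
import torch
from torch.utils.data import DataLoader
import tensorflow as tf
from model import CNN_Mnist_Relu  # the author's own model definition (model.py)
from torchvision import transforms, datasets
import onnx
from onnx_tf.backend import prepare


def load_pb(path_to_pb):
    """Load a frozen TensorFlow GraphDef (.pb) file into a new tf.Graph."""
    # tf.io.gfile and tf.compat.v1 are available in TF 1.13+ and in TF 2.x
    with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def, name='')
    return graph


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Take one batch of MNIST images to use as the example input for the ONNX export
train = datasets.MNIST('./data/', train=True, transform=transforms.Compose([transforms.ToTensor(), ]), download=True)
train_loader = DataLoader(train, batch_size=128)
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.to(device)

# Load the trained PyTorch parameters (map_location lets a GPU checkpoint load on a CPU-only machine)
model_param = torch.load('./e_data_3/fedavg/global_model.pt', map_location=device)
target_model = CNN_Mnist_Relu().to(device)
target_model.load_state_dict(model_param)
target_model.eval()  # export in inference mode (matters for dropout / batch norm)

# PyTorch -> ONNX; without dynamic_axes the batch dimension is fixed at 128
torch.onnx.export(target_model, images, './fe.onnx', input_names=['input'], output_names=['output'])

# Load the MNIST dataset (Keras copy; not used by the conversion itself)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# ONNX -> TensorFlow
model_onnx = onnx.load('./fe.onnx')
# Force IR version 6, presumably so that the installed onnx-tf accepts the model
model_onnx.ir_version = 6
tf_rep = prepare(model_onnx)
tf_rep.export_graph('output.pb')

'''
This commented-out part has not been tested successfully yet.
Note: older (TF 1.x based) onnx-tf versions write a frozen .pb file in export_graph,
while newer (TF 2.x based) versions write a SavedModel directory that load_pb cannot read.

tf_graph = load_pb('output.pb')
sess = tf.compat.v1.Session(graph=tf_graph)
# Get references to the input and output tensors
input_tensor = sess.graph.get_tensor_by_name('input:0')  # replace with your input tensor name
output_tensor = sess.graph.get_tensor_by_name('output:0')  # replace with your output tensor name
# Prepare the input data: the same 128-image batch used for the export, as a NumPy array (128, 1, 28, 28)
input_data = images.cpu().numpy()
# Run the model in the session
output_data = sess.run(output_tensor, feed_dict={input_tensor: input_data})
print(output_data)
# Model predictions: use NumPy here, because a tf.argmax op created at this point would
# belong to the default graph rather than tf_graph and could not be run with sess
predicted_class = np.argmax(output_data, axis=1)
# Print the prediction for the first image
print("Predicted Class:", predicted_class[0])
# Close the TensorFlow session
sess.close()
'''
```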
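The commented-out test feeds the PyTorch batch into the TensorFlow graph; the Keras MNIST arrays loaded in the script could be fed instead once they are in the same (N, 1, 28, 28) float layout. Below is a minimal NumPy sketch, assuming the same preprocessing as the export code (`transforms.ToTensor()` only, no normalization) and a batch size fixed at 128. `keras_mnist_to_nchw` reproduces that preprocessing, and `compare_logits` checks the conversion by comparing PyTorch and TensorFlow outputs for the same batch; both are hypothetical helpers, not part of any library.

```python
import numpy as np


def keras_mnist_to_nchw(x, batch_size=128):
    """Convert Keras MNIST images (N, 28, 28) uint8 into the export batch layout.

    transforms.ToTensor() scales pixels to [0, 1] and adds a channel axis,
    giving float32 arrays of shape (N, 1, 28, 28); this mimics that in NumPy.
    Only the first `batch_size` images are kept because the ONNX graph was
    exported with a fixed batch dimension.
    """
    batch = x[:batch_size].astype(np.float32) / 255.0
    return batch[:, np.newaxis, :, :]


def compare_logits(reference, candidate, atol=1e-4):
    """Compare PyTorch (reference) and TensorFlow (candidate) outputs.

    Returns the maximum absolute difference, whether np.allclose holds with
    the given `atol`, and the fraction of rows whose argmax class agrees.
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    within_tol = bool(np.allclose(reference, candidate, atol=atol))
    same_class = np.argmax(reference, axis=1) == np.argmax(candidate, axis=1)
    agreement = float(np.mean(same_class))
    return max_abs_diff, within_tol, agreement
```

With the names used in the script, the reference would be `target_model(images).detach().cpu().numpy()` and the candidate the `output_data` returned by `sess.run`. Converted graphs can differ from the original by small floating-point amounts, which is why a tolerance is used instead of exact equality.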
Main environment setup