【pytorch】torch.nn.Module.load_state_dict详解

参考博客:
https://blog.csdn.net/weixin_40522801/article/details/106563354
https://blog.csdn.net/yangwangnndd/article/details/100207686

函数定义:

load_state_dict(state_dict, strict=True)

作用
使用 state_dict 反序列化模型参数字典。用来加载模型参数。将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中。
概况:给模型对象加载训练好的模型参数,即加载模型参数

关于state_dict
在PyTorch中,一个torch.nn.Module模型中的可学习参数(比如weights和biases),模型的参数通过model.parameters()获取。而state_dict就是一个简单的Python dictionary,其功能是将每层与层的参数张量之间一一映射。
注意,只有包含了可学习参数(卷积层、线性层等)的层和已注册的命令(registered buffers,比如batchnorm的running_mean)才有模型的state_dict入口。优化方法目标(torch.optim)也有state_dict,其中包含的是关于优化器状态的信息和使用到的超参数。
因为state_dict目标是Python dictionaries,所以它们可以很轻松地实现保存、更新、变化和再存储,从而给PyTorch模型和优化器增加了大量的模块化(modularity)。

使用示例:

model.load_state_dict(torch.load('pose_dekr_hrnetw32_coco.pth'), strict=True)

官方函数说明

Copies parameters and buffers from state_dict into this module
and its descendants. If strict is True, then the keys of
state_dict must exactly match the keys returned by this
module’s state_dict() function.
从函数接收的参数state_dict中将参数和缓冲拷贝到当前这个模块及其子模块中.
如果函数接受的参数strict是True,那么state_dict的关键字必须确切地严格地和
该模块的state_dict()函数返回的关键字相匹配.

Parameters 参数

state_dict (dict) – a dict containing parameters and persistent buffers.
state_dict (字典类型) – 一个包含参数和持续性缓冲的字典
往往是pytorch模型pth文件
strict (布尔类型, 可选) – 该参数用来指明是否需要强制严格匹配,:state_dict中的关键字是否需要和该模块的state_dict()方法返回的关键字强制严格匹配.默认值是True

Returns 返回

返回类型:NamedTuple with missing_keys and unexpected_keys fields
missing_keys is a list of str containing the missing keys
missing_keys是一个字符串的列表,该列表包含了所有缺失的关键字.
unexpected_keys is a list of str containing the unexpected keys
unexpected_keys是一个字符串的列表,该列表包含了意料之外的关键字,
import scipy.io import torch from torch_geometric.data import Data from torch_geometric.nn import GNNExplainer from torch_geometric.utils import to_networkx import matplotlib.pyplot as plt import networkx as nx from edge_index_to_adj_node_edge import edge_index_to_adj_node_edge # 加载 .mat 文件 mat_file = "D:/GNN/wy0303-62OLD/data/data0318/data0318.mat" # 替换为你的 .mat 文件路径 mat_data = scipy.io.loadmat(mat_file) # 提取数据 edge_index = mat_data['edge_index'] # 边的索引 n_s = mat_data['n_s'] # 源节点特征 n_d = mat_data['n_d'] # 目标节点特征 m_s = mat_data['m_s'] # 可能是边特征 m_d = mat_data['m_d'] # 可能是边特征 E = mat_data['E'] # 可能是一些额外的图特征 sc = mat_data['sc'] # 可能是额外的图特征 adj_matrix = edge_index_to_adj_node_edge(edge_index) # 如果存在邻接矩阵 # 转换为 PyTorch 张量 edge_index_tensor = torch.tensor(edge_index, dtype=torch.long) node_features_tensor = torch.cat((torch.tensor(n_s, dtype=torch.float), torch.tensor(n_d, dtype=torch.float)), dim=0) # 处理节点特征(聚合特征维度,得到一个二维张量) node_features_processed = node_features_tensor.mean(dim=(1, 2, 3)) # 你可以根据需要修改此方法 # 创建 PyTorch Geometric 数据对象 data = Data(x=node_features_processed, edge_index=edge_index_tensor) # 加载预训练的模型 model = torch.load('D:/GNN/wy0303-62OLD/best_model.pt') model.eval() # 设置为评估模式 # 选择一个节点进行解释 node_idx = 0 # 选择图中的一个节点进行解释 # 使用 GNNExplainer 进行解释 explainer = GNNExplainer(model, epochs=200) print(f"edge_index type: {type(data.edge_index)}") print(f"edge_index shape: {data.edge_index.shape}") # 在调用 explain_node 时传递所有额外的输入(包括 edge_index) # 这里我们明确传递 edge_index 给模型和 explainer # n_s, n_d, m_s, m_d, E, sc,edge_index,adj_matrix explanation = explainer.explain_node( node_idx, data.x, data.edge_index, m_d=m_d, m_s=m_s, E=E, sc=sc, adj_matrix=adj_matrix, ) # 可视化结果 G = to_networkx(data) # 转换为 NetworkX 图对象 pos = nx.spring_layout(G) # 获取图的布局 # 可视化图 plt.figure(figsize=(8, 8)) nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500) nx.draw_networkx_edges(G, pos, edgelist=explanation.edge_mask, edge_color='r'
最新发布
04-03
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值