import torch
import torch.nn as nn
class ModelA(nn.Module):
    """Toy module holding a single ``nn.Linear(2, 3)`` layer named ``A``.

    Because the layer is registered as ``self.A``, this module's state_dict
    contains the keys ``A.weight`` and ``A.bias``.
    """

    def __init__(self):
        super().__init__()
        self.A = nn.Linear(2, 3)

    def forward(self, A):
        """Placeholder forward pass: not implemented in this demo, returns None."""
        pass


class ModelB(nn.Module):
    """Toy module that nests a ``ModelA`` plus its own ``nn.Linear(2, 3)``.

    The nested instance is registered as ``self.model_a``, so its parameters
    appear in this module's state_dict under the ``model_a.`` prefix
    (``model_a.A.weight``, ``model_a.A.bias``), next to this module's own
    ``A.weight`` and ``A.bias``.
    """

    def __init__(self):
        super().__init__()
        self.model_a = ModelA()
        self.A = nn.Linear(2, 3)

    def forward(self, x):
        """Placeholder forward pass: not implemented in this demo, returns None."""
        pass


def _print_state_dict(state_dict, *, verbose=False):
    """Print each entry of *state_dict* in sorted key order.

    For every key this prints the key, the tensor's size, and the tensor.
    With ``verbose=True`` the key is printed together with its type, and each
    entry is framed by a line of 20 dashes before and after.
    """
    for key in sorted(state_dict):
        parameter = state_dict[key]
        if verbose:
            print('-' * 20)
            print(type(key), key)
        else:
            print(key)
        print(parameter.size())
        print(parameter)
        if verbose:
            print('-' * 20)


# Prefix under which ModelB registers its nested ModelA (``self.model_a``).
_SUBMODULE_PREFIX = 'model_a.'

print("Model")
modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 80)
_print_state_dict(modelA_dict)

modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-' * 80)
_print_state_dict(modelB_dict, verbose=True)
print('-' * 80)

# Partial load: rename ModelA's keys to the names they carry inside ModelB,
# keep only the ones ModelB actually has, merge them over ModelB's own
# state_dict, and load the merged dict back. ModelB's own ``A.*`` entries are
# left as they were. (An equivalent shortcut for this exact layout is
# ``modelB.model_a.load_state_dict(modelA_dict)``.)
pretrained_dict = modelA_dict
model_dict = modelB_dict
pretrained_dict = {
    _SUBMODULE_PREFIX + k: v
    for k, v in pretrained_dict.items()
    if _SUBMODULE_PREFIX + k in model_dict
}
model_dict.update(pretrained_dict)
modelB.load_state_dict(model_dict)

modelB_dict = modelB.state_dict()
_print_state_dict(modelB_dict)
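# ————————————————
# 版权声明:本文为CSDN博主「weixin_39945810」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
# (Copyright notice: this is an original article by CSDN blogger "weixin_39945810", licensed under CC BY-SA 4.0; reposts must include the original source link and this notice.)
# 原文链接:https://blog.csdn.net/weixin_39945810/article/details/112936675
# (Original article: https://blog.csdn.net/weixin_39945810/article/details/112936675)
# pytorch加载部分模型 (Article title: "PyTorch: loading part of a model")
# 最新推荐文章于 2024-07-26 20:52:07 发布 (blog page residue: "latest recommended article published 2024-07-26 20:52:07")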