查看模型字典
for k, v in pretrainedd_dict.items():
print(k,end='\t')
print()
for k, v in D_dict.items():
print(k,end='\t')
pretrainedd_dict:
seq.0.weight seq.0.bias seq.3.weight seq.3.bias seq.6.weight seq.6.bias
D_dict:
seq.0.weight seq.0.bias seq.3.weight seq.3.bias seq.6.weight seq.6.bias last.weight last.bias
取出已训练的模型encoder初始化模型D
# 使用encoder初始化D的部分参数
D_dict = D.state_dict()#取出自己网络的参数字典
D_filename = '{}/encoder_{}.pth'.format(args.model_dir, args.start_step)
pretrainedd_dict = torch.load(D_filename)#加载预训练网络的参数字典
for k,v in pretrainedd_dict.items():
if v.size()==D_dict[k].size():
D_dict[k] = pretrainedd_dict[k]
D.load_state_dict(D_dict)
print('D初始化完成')