问题描述:在做无监督图像融合时涉及到了编码器和解码器。编码器输出的特征图,经过处理后要在通过解码器。这就要求参数文件要分开读取。
解决问题的核心;pytorch是以字典的方式存储参数(层的名字:对应参数)
下面进行举例:
class decode(): # 解码器
def __init__(self):
# initialize model
self.device = torch.device('cpu') # 设备选择
self.model = Decodenet()
self.model_path = os.path.join(os.getcwd(), "nets", "parameters", "lp+lssim_se_sf_net_times30.pkl") # 编码器参数的路径
self.save_model = torch.load(self.model_path, map_location=self.device)
self.model_dict = self.model.state_dict() # 模型key
self.state_dict = {k: v for k, v in self.save_model.items() if k in self.model_dict.keys()} # 这里还要修改
self.model_dict.update(self.state_dict)
self.model.load_state_dict(self.model_dict)
self.model.to(self.device)
self.model.eval()
def reduction(self, f_m):
img = self.model(f_m)
return img
class Decodenet(nn.Module): # 解码的网络
def __init__(self):
super(Decodenet, self).__init__()
self.conv_decode_1 = self.conv_block(64, 64)
self.conv_decode_2 = self.conv_block(64, 32)
self.conv_decode_3 = self.conv_block(32, 16)
self.conv_decode_4 = self.conv_block(16, 1)
def forward(self , se_cat3):
with torch.no_grad():
decode_block1 = self.conv_decode_1(se_cat3)
decode_block2 = self.conv_decode_2(decode_block1)
decode_block3 = self.conv_decode_3(decode_block2)
output = self.conv_decode_4(decode_block3)
return output
@staticmethod
def conv_block(in_channels, out_channels, kernel_size=3):
block = torch.nn.Sequential(
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(out_channels),
)
return block