FusionNet-Model(pytorch版本)

FusionNet-Model(PyTorch 版本)


训练、验证代码逻辑




All.ipynb


在这里插入图片描述

在这里插入图片描述


import torch.nn as nn
import torch
def conv_block(in_dim, out_dim, act_fn, stride=1):
    """Build a 3x3 conv -> BatchNorm -> activation stage.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended after the batch norm.
        stride: conv stride; 2 halves the spatial resolution (default 1).

    Returns:
        An ``nn.Sequential`` implementing the stage.
    """
    layers = [
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)


def conv_trans_block(in_dim, out_dim, act_fn):
    """Build a 2x up-sampling stage: transposed 3x3 conv -> BatchNorm -> activation.

    With stride=2, padding=1 and output_padding=1 the transposed conv
    exactly doubles the input height and width.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended after the batch norm.

    Returns:
        An ``nn.Sequential`` implementing the stage.
    """
    layers = [
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2,
                           padding=1, output_padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)


def conv_block_3(in_dim, out_dim, act_fn):
    """Stack three 3x3 conv stages; the last one has no activation.

    The missing activation on the final stage lets the caller add this
    block's output to another tensor (residual style) before activating.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels for every stage.
        act_fn: activation module used by the first two stages.

    Returns:
        An ``nn.Sequential`` of conv/BN stages.
    """
    stages = [
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    ]
    return nn.Sequential(*stages)


class Conv_residual_conv(nn.Module):
    """Residual conv unit: entry conv, 3-conv residual branch, exit conv.

    The entry conv maps ``in_dim`` to ``out_dim``; its output is added
    to the output of a three-conv branch (skip connection) before the
    final conv.
    """

    def __init__(self, in_dim, out_dim, act_fn):
        super(Conv_residual_conv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.conv_1 = conv_block(in_dim, out_dim, act_fn)
        self.conv_2 = conv_block_3(out_dim, out_dim, act_fn)
        self.conv_3 = conv_block(out_dim, out_dim, act_fn)

    def forward(self, input):
        head = self.conv_1(input)
        branch = self.conv_2(head)
        # Element-wise residual sum of entry conv and 3-conv branch.
        return self.conv_3(head + branch)
class Fusionnet(nn.Module):
    """FusionNet encoder-decoder for dense per-pixel prediction.

    Four down-sampling stages, a bridge, and four up-sampling stages.
    Each skip connection is the element-wise *average* of the decoder
    feature map and the matching encoder feature map, so input height
    and width must be divisible by 16 for the shapes to line up.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels (e.g. segmentation classes).
        ngf: base number of feature maps; doubled at each encoder stage.
        out_clamp: stored for interface compatibility; not used in forward.
    """

    def __init__(self, input_nc, output_nc, ngf, out_clamp=None):
        super(Fusionnet, self).__init__()

        self.out_clamp = out_clamp
        self.in_dim = input_nc
        self.out_dim = ngf
        self.final_out_dim = output_nc

        act_fn = nn.ReLU()
        act_fn_2 = nn.ELU(inplace=True)

        # encoder: residual block, then strided conv for 2x down-sampling
        self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
        self.pool_1 = conv_block(self.out_dim, self.out_dim, act_fn, 2)
        self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
        self.pool_2 = conv_block(self.out_dim * 2, self.out_dim * 2, act_fn, 2)
        self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
        self.pool_3 = conv_block(self.out_dim * 4, self.out_dim * 4, act_fn, 2)
        self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
        self.pool_4 = conv_block(self.out_dim * 8, self.out_dim * 8, act_fn, 2)

        # bridge between encoder and decoder, at 1/16 resolution
        self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)

        # decoder: transposed conv for 2x up-sampling, then residual block
        self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
        self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)
        self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
        self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)
        self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
        self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)
        self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
        self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)

        # final 3x3 projection to the requested output channel count
        self.out = nn.Conv2d(self.out_dim, self.final_out_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        """Run the network.

        Args:
            input: tensor of shape (N, input_nc, H, W) with H and W
                divisible by 16.

        Returns:
            Tensor of shape (N, output_nc, H, W).
        """
        # NOTE(review): removed the per-layer debug ``print`` calls that
        # ran on every forward pass; they were leftover shape tracing.

        # ---- encoder ----
        down_1 = self.down_1(input)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)

        # ---- bridge ----
        bridge = self.bridge(pool_4)

        # ---- decoder: up-sample, then average with the matching
        # encoder feature map (skip connection) ----
        deconv_1 = self.deconv_1(bridge)
        skip_1 = (deconv_1 + down_4) / 2
        up_1 = self.up_1(skip_1)
        deconv_2 = self.deconv_2(up_1)
        skip_2 = (deconv_2 + down_3) / 2
        up_2 = self.up_2(skip_2)
        deconv_3 = self.deconv_3(up_2)
        skip_3 = (deconv_3 + down_2) / 2
        up_3 = self.up_3(skip_3)
        deconv_4 = self.deconv_4(up_3)
        skip_4 = (deconv_4 + down_1) / 2
        up_4 = self.up_4(skip_4)

        return self.out(up_4)
def _demo():
    """Smoke-test the network on a random RGB batch and print the output shape."""
    # Random input: batch 1, 3 channels, 352x480 (both divisible by 16,
    # as required by the skip connections).
    rgb = torch.randn(1, 3, 352, 480)
    # 3 input channels, 12 output channels, 64 base feature maps.
    net = Fusionnet(3, 12, 64)
    out = net(rgb)
    print('-----' * 5)
    print(out.shape)
    print('-----' * 5)


if __name__ == "__main__":
    # Guarded so importing this module no longer runs a forward pass
    # as a side effect.
    _demo()

在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个基于 PyTorch 的图像融合的示例代码,它将两张大小相同的图像进行融合:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms  # 修正:原代码使用了 transforms 但未导入
from PIL import Image

# 定义数据集
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, img_path1, img_path2):
        self.img_path1 = img_path1
        self.img_path2 = img_path2
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __getitem__(self, index):
        img1 = Image.open(self.img_path1)
        img2 = Image.open(self.img_path2)
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        return img1, img2

    def __len__(self):
        return 1

# 定义模型
class FusionNet(nn.Module):
    def __init__(self):
        super(FusionNet, self).__init__()
        # 修正:forward 中将两张 3 通道图像在通道维拼接,得到 6 个通道,
        # 因此第一层卷积的输入通道数必须是 6(原代码写成 3 会报错)。
        self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, img1, img2):
        x = torch.cat([img1, img2], dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        return x

# 训练数据集路径
img_path1 = "/path/to/image1.jpg"
img_path2 = "/path/to/image2.jpg"

# 定义数据加载器
dataset = ImageDataset(img_path1, img_path2)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

# 定义模型、损失函数和优化器
model = FusionNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0
    for data in dataloader:
        img1, img2 = data
        optimizer.zero_grad()
        output = model(img1, img2)
        loss = criterion(output, img1)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, running_loss))

# 图像融合
model.eval()
with torch.no_grad():
    img1 = Image.open(img_path1)
    img2 = Image.open(img_path2)
    img1 = dataset.transform(img1).unsqueeze(0)
    img2 = dataset.transform(img2).unsqueeze(0)
    output = model(img1, img2)
    output = output.squeeze(0)
    output = output.permute(1, 2, 0)
    output = (output + 1) / 2.0
    output = output.detach().numpy() * 255.0
    output = output.astype('uint8')
    output = Image.fromarray(output)
    output.save("/path/to/fused_image.jpg")
```

这个示例代码中,我们首先定义了一个数据集类 `ImageDataset`,它将两张图片进行数据预处理,并返回给训练器。接着,我们定义了一个融合网络 `FusionNet`,它包含了四个卷积层和一个 ReLU 激活函数。在训练过程中,我们使用均方误差损失函数和 Adam 优化器进行训练。最后,我们使用训练好的模型将两张输入图像进行融合,并将输出保存为一张新的图像。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值