# VGG-BN based autoencoder (基于VGGBN的自编码器)

import pandas as pd
import torch
from torch import nn
from torchvision import models
from torchvision.models import VGG16_BN_Weights
from gen_AED_dataset import MyAEDataset
from torch.utils.data.dataloader import DataLoader


class EncoderVGG(nn.Module):
    """Encoder built from the feature extractor of a pretrained VGG16-BN.

    Every ``MaxPool2d`` in the backbone is replaced by an index-returning
    version so that a mirrored decoder can un-pool with the same indices.
    """

    channels_in = 3       # RGB input
    channels_code = 512   # channel depth of the code produced by VGG features

    def __init__(self):
        super(EncoderVGG, self).__init__()

        backbone = models.vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
        # Only the convolutional feature extractor is kept.
        del backbone.classifier
        del backbone.avgpool

        self.encoder = self._encodify_(backbone)

    def _encodify_(self, encoder):
        """Copy ``encoder.features`` into a ModuleList, swapping each
        ``MaxPool2d`` for one with ``return_indices=True`` (needed by
        ``MaxUnpool2d`` in the decoder)."""
        converted = nn.ModuleList()
        for layer in encoder.features:
            if not isinstance(layer, nn.MaxPool2d):
                converted.append(layer)
                continue
            converted.append(nn.MaxPool2d(kernel_size=layer.kernel_size,
                                          stride=layer.stride,
                                          padding=layer.padding,
                                          return_indices=True))
        return converted

    def forward(self, x):
        """Return ``(code, pool_indices)``; indices are collected in the
        order the pooling layers are encountered."""
        indices = []
        out = x
        for layer in self.encoder:
            result = layer(out)
            # Index-returning pooling layers yield a (tensor, indices) pair.
            if isinstance(result, tuple) and len(result) == 2:
                out, pool_idx = result
                indices.append(pool_idx)
            else:
                out = result
        return out, indices


class DecoderVGG(nn.Module):
    """Decoder mirroring an ``EncoderVGG``: the encoder's layers are
    inverted in reverse order, with ``MaxUnpool2d`` layers consuming the
    pooling indices recorded during encoding.
    """

    channels_in = EncoderVGG.channels_code
    channels_out = 3

    def __init__(self, encoder):
        super(DecoderVGG, self).__init__()
        self.decoder = self._invert_(encoder)

    def _invert_(self, encoder):
        """Build the mirrored layer list from ``encoder`` (a ModuleList).

        Conv2d -> ConvTranspose2d + BatchNorm2d + ReLU (channels swapped),
        MaxPool2d -> MaxUnpool2d; other layers (the encoder's own BN/ReLU)
        are dropped. The trailing norm + activation are stripped so the
        output layer stays linear.
        """
        inverted = []
        for layer in reversed(encoder):
            if isinstance(layer, nn.Conv2d):
                inverted.append(nn.ConvTranspose2d(in_channels=layer.out_channels,
                                                   out_channels=layer.in_channels,
                                                   kernel_size=layer.kernel_size,
                                                   stride=layer.stride,
                                                   padding=layer.padding))
                inverted.append(nn.BatchNorm2d(layer.in_channels))
                inverted.append(nn.ReLU(inplace=True))
            elif isinstance(layer, nn.MaxPool2d):
                inverted.append(nn.MaxUnpool2d(kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding))
        # Drop the final BatchNorm2d + ReLU following the last transposed conv.
        return nn.ModuleList(inverted[:-2])

    def forward(self, x, pool_indices):
        """Decode ``x`` using ``pool_indices`` as produced by EncoderVGG
        (consumed in reverse order)."""
        indices_reversed = list(reversed(pool_indices))
        out = x
        next_idx = 0
        for layer in self.decoder:
            if isinstance(layer, nn.MaxUnpool2d):
                out = layer(out, indices=indices_reversed[next_idx])
                next_idx += 1
            else:
                out = layer(out)
        return out


class AutoEncoderVGG(nn.Module):
    """Complete VGG16-BN autoencoder: ``EncoderVGG`` followed by a
    ``DecoderVGG`` built from the encoder's own layer list."""

    channels_in = EncoderVGG.channels_in
    channels_code = EncoderVGG.channels_code
    channels_out = DecoderVGG.channels_out

    def __init__(self):
        super(AutoEncoderVGG, self).__init__()
        self.encoder = EncoderVGG()
        self.decoder = DecoderVGG(self.encoder.encoder)

    def forward(self, x):
        """Encode ``x`` to a code, then decode back to image space."""
        code, pool_indices = self.encoder(x)
        return self.decoder(code, pool_indices)


if __name__ == '__main__':
    # Train the autoencoder to reconstruct img2 from img1 with an MSE loss.
    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AutoEncoderVGG().to(device)
    criterion = nn.MSELoss()
    dataset = MyAEDataset(path='./samples')
    dataloader = DataLoader(dataset, batch_size=1)
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.0001, momentum=0.8)
    model.train()
    loss_log = []
    for epoch in range(1000):
        for idx, (img1, img2) in enumerate(dataloader):
            # Dataloader batches are fresh tensors; detach + cast is enough,
            # no clone() copy needed.
            x = img1.detach().to(torch.float).to(device)
            y = img2.detach().to(torch.float).to(device)
            # Model output is already float32 on `device`; extra casts/moves
            # were redundant.
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_log.append(loss.item())
            if epoch % 2 == 0:
                print(loss.item())
    pd.DataFrame(loss_log).to_csv('loss_log.csv')
    torch.save(model.state_dict(), 'VGGAE.pt')

# --- NOTE(review): the text below is CSDN webpage residue from the blog this
# --- code was scraped from (vote/collect buttons, tip-payment widget). It is
# --- not code; commented out so the file parses. Safe to delete entirely.
#   • 0
#     点赞
#   • 0
#     收藏
#     觉得还不错? 一键收藏
#   • 0
#     评论
#
# “相关推荐”对你有帮助么?
#
#   • 非常没帮助
#   • 没帮助
#   • 一般
#   • 有帮助
#   • 非常有帮助
# 提交
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值