AutoEncoder-PyTorch

import os
import h5py
import shutil
import torch as t
import numpy as np
import torch.nn as nn
import torchnet as tnt
from torchvision import transforms
from torch.utils import data
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid


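# Convolutional autoencoder: two stride-2 convs reduce the input to a single
# 1-channel 10x10 feature map, an MLP squeezes those 100 values into a 32-dim
# code, and a mirrored MLP plus transposed convs reconstruct the image.
# The Linear(100, 64) layer implies a 10x10 encoder output, i.e. roughly
# 40x40 single-channel inputs (40 -> 20 -> 10 after the two stride-2 convs).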
class AutoEncode(nn.Module):
    def __init__(self, in_channels=1):
        super(AutoEncode, self).__init__()
        self.encoder = nn.Sequential(*[
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            # nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            # nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
            # nn.ReLU(inplace=True),
        ])

        self.encoder_mlp = nn.Sequential(*[
            nn.Linear(100, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32)
        ])

        self.decoder_mlp = nn.Sequential(*[
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 100)
        ])

        self.decoder = nn.Sequential(*[
            nn.ConvTranspose2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            # nn.ReLU(inplace=True),
            # nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            # nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=6, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=in_channels, kernel_size=6, stride=2, padding=2)
        ])

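    # Returns (code, reconstruction): the 32-dim bottleneck vector and the
    # decoded image. The view() calls assume the conv encoder emits an
    # (N, 1, 10, 10) map and the decoder MLP emits a square number of features.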
    def forward(self, _input):
        encoder = self.encoder(_input)
        encoder = self.encoder_mlp(encoder.view(encoder.size()[0], encoder.size()[2] * encoder.size()[3]))

        decoder = self.decoder_mlp(encoder)
        decoder = self.decoder(decoder.view(decoder.size()[0], 1, int(np.sqrt(decoder.size()[1])), int(np.sqrt(decoder.size()[1]))))
        return encoder, decoder


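# Map-style dataset backed by an HDF5 file: every key in train.h5 / val.h5 is
# expected to hold one image patch, returned as a tensor (the model expects
# float patches with a leading channel dimension).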
class DataSet(data.Dataset):
    def __init__(self, train=True):
        super(DataSet, self).__init__()
        self.train = train
        if self.train:
            self.h5f = h5py.File('train.h5', 'r')
        else:
            self.h5f = h5py.File('val.h5', 'r')
        self.keys = list(self.h5f.keys())
        # self.transform = transforms.Compose([
        #     transforms.ToTensor()
        # ])

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, item):
        return t.from_numpy(np.array(self.h5f[self.keys[item]]))

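# A minimal sketch of how such a file could be built (an assumption about the
# data layout, not taken from this listing): one dataset per patch, stored as
# float32 in [0, 1] with shape (1, 40, 40) so batches come out as (N, 1, 40, 40).
#
# with h5py.File('train.h5', 'w') as f:
#     for i, patch in enumerate(patches):   # patches: iterable of (1, 40, 40) float32 arrays
#         f.create_dataset(str(i), data=patch)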

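# Training setup: MSE reconstruction loss, Adam (lr=1e-4), and TensorBoard
# logging under ./logs/ae (the log directory is wiped at the start of each run).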
model = AutoEncode()

device = t.device('cuda' if t.cuda.is_available() else 'cpu')

if os.path.exists('./logs/ae'):
    shutil.rmtree('./logs/ae')
os.makedirs('./logs/ae')
writer = SummaryWriter('./logs/ae')

train_loss = tnt.meter.AverageValueMeter()
train_psnr = tnt.meter.AverageValueMeter()
val_loss = tnt.meter.AverageValueMeter()
val_psnr = tnt.meter.AverageValueMeter()

loss = nn.MSELoss().to(device)
val_loader = data.DataLoader(DataSet(train=False), batch_size=64, shuffle=False)
train_loader = data.DataLoader(DataSet(train=True), batch_size=32, shuffle=True)


optimizer = t.optim.Adam(model.parameters(), lr=1e-4)
model.to(device)


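# Training loop: 300 epochs over train_loader; PSNR is derived from the MSE
# loss (which assumes pixel values in [0, 1]), and scalars plus image grids
# are written to TensorBoard every 1000 steps.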
def train():
    step = 0
    for epoch in range(300):
        for _, img in enumerate(train_loader):
            step += 1
            img = img.to(device)
            optimizer.zero_grad()
            _, out = model(img)
            _loss = loss(out, img)
            _psnr = 10.0 * np.log10(1.0 / _loss.detach().cpu().numpy())
            train_loss.add(_loss.item())
            train_psnr.add(_psnr)
            _loss.backward()
            optimizer.step()
            if step % 1000 == 0:
                print(train_psnr.value()[0], train_loss.value()[0])
                writer.add_scalar('train_psnr', train_psnr.value()[0], step)
                writer.add_scalar('train_loss', train_loss.value()[0], step)
                writer.add_image('train_input', make_grid(img.detach().cpu(), nrow=8, normalize=True, scale_each=True), step)
                writer.add_image('train_output', make_grid(out.detach().cpu(), nrow=8, normalize=True, scale_each=True), step)
                train_psnr.reset()
                train_loss.reset()

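# val_loader and the val_* meters are declared above but never used in the
# listing; a validation pass consistent with the training loop could look
# like this (a sketch under that assumption):
def validate(step):
    model.eval()
    with t.no_grad():
        for img in val_loader:
            img = img.to(device)
            _, out = model(img)
            _loss = loss(out, img)
            val_loss.add(_loss.item())
            val_psnr.add(10.0 * np.log10(1.0 / _loss.cpu().numpy()))
    writer.add_scalar('val_psnr', val_psnr.value()[0], step)
    writer.add_scalar('val_loss', val_loss.value()[0], step)
    val_psnr.reset()
    val_loss.reset()
    model.train()


# train() is defined above but never invoked; a minimal entry point:
if __name__ == '__main__':
    train()
    writer.close()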