import os
import h5py
import shutil
import torch as t
import numpy as np
import torch.nn as nn
import torchnet as tnt
from torchvision import transforms
from torch.utils import data
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid
class AutoEncode(nn.Module):
    """Convolutional autoencoder with an MLP bottleneck.

    Encoder: two stride-2 convs (spatial /4) followed by a 1-channel conv,
    then an MLP that compresses the 100 flattened features down to a
    32-dim code. Decoder mirrors this: MLP back to 100 features, reshaped
    to a 1x10x10 map, then transposed convs back up to the input size.

    NOTE(review): the MLP sizes hard-code a 10x10 encoder output, which
    implies 40x40 spatial inputs — confirm against the dataset.
    """

    def __init__(self, in_channels=1):
        super(AutoEncode, self).__init__()
        # Conv feature extractor: (B, C, H, W) -> (B, 1, H/4, W/4).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        # Flattened 100 features -> 32-dim latent code.
        self.encoder_mlp = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
        )
        # 32-dim latent code -> 100 features (reshaped to 1x10x10).
        self.decoder_mlp = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 100),
        )
        # Transposed-conv upsampler: (B, 1, H/4, W/4) -> (B, C, H, W).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=6, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=in_channels, kernel_size=6, stride=2, padding=2),
        )

    def forward(self, _input):
        """Return (latent_code, reconstruction) for a batch of images."""
        feat = self.encoder(_input)
        # Flatten the single-channel feature map to (B, H*W) for the MLP.
        flat = feat.view(feat.size(0), feat.size(2) * feat.size(3))
        code = self.encoder_mlp(flat)
        expanded = self.decoder_mlp(code)
        # Reshape the expanded features back to a square 1-channel map.
        side = int(np.sqrt(expanded.size(1)))
        recon = self.decoder(expanded.view(expanded.size(0), 1, side, side))
        return code, recon
class DataSet(data.Dataset):
    """HDF5-backed dataset: yields one tensor per key of train.h5 / val.h5.

    FIX: now subclasses torch.utils.data.Dataset (the PyTorch contract for
    objects handed to DataLoader); behavior is otherwise unchanged.

    NOTE(review): the h5py handle is opened here and never closed, and h5py
    handles are not fork-safe — assumes single-process loading
    (num_workers=0). TODO confirm.
    """

    def __init__(self, train=True):
        super(DataSet, self).__init__()
        self.train = train
        # Split is selected purely by filename convention.
        if self.train:
            self.h5f = h5py.File('train.h5', 'r')
        else:
            self.h5f = h5py.File('val.h5', 'r')
        # Snapshot the keys so indexing order is stable.
        self.keys = list(self.h5f.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, item):
        # Copy the HDF5 dataset into memory, then wrap it as a torch tensor.
        return t.from_numpy(np.array(self.h5f[self.keys[item]]))
# --- Training setup (module level): model, logging, data, optimizer. ---
model = AutoEncode()
device = t.device('cuda')
# Start each run with a clean TensorBoard log directory.
if os.path.exists('./logs/ae'):
    shutil.rmtree('./logs/ae')
os.makedirs('./logs/ae')
writer = SummaryWriter('./logs/ae')
train_loss = tnt.meter.AverageValueMeter()
train_psnr = tnt.meter.AverageValueMeter()
val_loss = tnt.meter.AverageValueMeter()
val_psnr = tnt.meter.AverageValueMeter()
loss = nn.MSELoss().to(device)
# BUG FIX: the shuffle flags were inverted — the training loader must
# shuffle each epoch; validation order does not need to be randomized.
val_loader = data.DataLoader(DataSet(train=False), batch_size=64, shuffle=False)
train_loader = data.DataLoader(DataSet(train=True), batch_size=32, shuffle=True)
optimizer = t.optim.Adam(model.parameters(), lr=1e-4)
model.to(device)
def train():
    """Train the autoencoder for 300 epochs.

    Logs running loss/PSNR and input/output image grids to TensorBoard
    every 1000 steps, then resets the running meters.

    BUG FIX: the original called ``optimizer.zero_grad()`` right after
    ``backward()`` and never called ``optimizer.step()``, so gradients
    were discarded and the weights never updated.
    """
    step = 0
    for epoch in range(300):
        for _, img in enumerate(train_loader):
            step += 1
            img = img.to(device)  # hoisted: was moved to device twice
            optimizer.zero_grad()
            _, out = model(img)
            _loss = loss(out, img)
            _loss.backward()
            optimizer.step()  # apply the gradient update (was missing)
            # PSNR assumes pixel values normalized to [0, 1] (peak = 1.0).
            _psnr = 10.0 * np.log10(1.0 / _loss.detach().cpu().numpy())
            train_loss.add(_loss.item())
            train_psnr.add(_psnr)
            if step % 1000 == 0:
                print(train_psnr.value()[0], train_loss.value()[0])
                writer.add_scalar('train_psnr', train_psnr.value()[0], step)
                writer.add_scalar('train_loss', train_loss.value()[0], step)
                writer.add_image('train_input', make_grid(img.detach(), nrow=8, normalize=True, scale_each=True), step)
                writer.add_image('train_output', make_grid(out.detach(), nrow=8, normalize=True, scale_each=True), step)
                train_psnr.reset()
                train_loss.reset()
# AutoEncoder-PyTorch
# (blog residue: "latest recommended article published 2023-11-28 09:07:24")