import pandas as pd
import torch
from torch import nn
from torchvision import models
from torchvision.models import VGG16_BN_Weights
from gen_AED_dataset import MyAEDataset
from torch.utils.data.dataloader import DataLoader
class EncoderVGG(nn.Module):
    """VGG16-BN feature extractor whose max-pool layers also return their
    argmax indices, so a mirrored decoder can unpool exactly.

    Built from a pretrained torchvision VGG16-BN with the classifier and
    average-pool head removed; only the convolutional feature stack remains.
    """
    channels_in = 3       # RGB input
    channels_code = 512   # channels of the code emitted by the last conv stage

    def __init__(self):
        super(EncoderVGG, self).__init__()
        backbone = models.vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
        # Drop the fully-connected head; only `features` is kept.
        del backbone.classifier
        del backbone.avgpool
        self.encoder = self._encodify_(backbone)

    def _encodify_(self, encoder):
        """Copy ``encoder.features`` into a ModuleList, swapping every
        MaxPool2d for an index-returning equivalent."""
        converted = nn.ModuleList()
        for layer in encoder.features:
            if not isinstance(layer, nn.MaxPool2d):
                converted.append(layer)
                continue
            converted.append(nn.MaxPool2d(kernel_size=layer.kernel_size,
                                          stride=layer.stride,
                                          padding=layer.padding,
                                          return_indices=True))
        return converted

    def forward(self, x):
        """Encode ``x``; return ``(code, pool_indices)`` where
        ``pool_indices`` holds one index tensor per pooling layer."""
        pool_indices = []
        out = x
        for layer in self.encoder:
            result = layer(out)
            # A pooling layer yields (tensor, indices); every other layer
            # yields a plain tensor.
            if isinstance(result, tuple) and len(result) == 2:
                out, indices = result
                pool_indices.append(indices)
            else:
                out = result
        return out, pool_indices
class DecoderVGG(nn.Module):
    """Mirror of :class:`EncoderVGG`: rebuilds an image from the code by
    running the encoder's layers in reverse, with transposed convolutions
    in place of convolutions and max-unpooling in place of max-pooling."""
    channels_in = EncoderVGG.channels_code
    channels_out = 3

    def __init__(self, encoder):
        super(DecoderVGG, self).__init__()
        self.decoder = self._invert_(encoder)

    def _invert_(self, encoder):
        """Build the reversed layer list from the encoder's ModuleList.

        Each Conv2d becomes ConvTranspose2d + BatchNorm2d + ReLU; each
        MaxPool2d becomes a MaxUnpool2d; the encoder's own BatchNorm/ReLU
        layers are skipped.  The trailing norm/activation pair is dropped
        so the network ends on the raw reconstruction.
        """
        inverted = []
        for layer in reversed(encoder):
            if isinstance(layer, nn.Conv2d):
                inverted.append(nn.ConvTranspose2d(in_channels=layer.out_channels,
                                                   out_channels=layer.in_channels,
                                                   kernel_size=layer.kernel_size,
                                                   stride=layer.stride,
                                                   padding=layer.padding))
                inverted.append(nn.BatchNorm2d(layer.in_channels))
                inverted.append(nn.ReLU(inplace=True))
            elif isinstance(layer, nn.MaxPool2d):
                inverted.append(nn.MaxUnpool2d(kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding))
        # Trim the final BatchNorm + ReLU so pixels come out unnormalized.
        return nn.ModuleList(inverted[:-2])

    def forward(self, x, pool_indices):
        """Decode ``x`` back to an image, consuming ``pool_indices`` in
        reverse order to match the reversed layer stack."""
        indices_stack = list(reversed(pool_indices))
        next_unpool = 0
        out = x
        for layer in self.decoder:
            if isinstance(layer, nn.MaxUnpool2d):
                out = layer(out, indices=indices_stack[next_unpool])
                next_unpool += 1
            else:
                out = layer(out)
        return out
class AutoEncoderVGG(nn.Module):
    """Complete VGG16-BN autoencoder: encoder plus mirrored decoder.

    ``forward`` returns the reconstruction of its input image batch.
    """
    channels_in = EncoderVGG.channels_in
    channels_code = EncoderVGG.channels_code
    channels_out = DecoderVGG.channels_out

    def __init__(self):
        super(AutoEncoderVGG, self).__init__()
        self.encoder = EncoderVGG()
        # The decoder is derived from, and mirrors, the encoder's layers.
        self.decoder = DecoderVGG(self.encoder.encoder)

    def forward(self, x):
        """Encode ``x`` to the latent code, then decode it back to an
        image, reusing the encoder's pooling indices for unpooling."""
        code, pool_indices = self.encoder(x)
        return self.decoder(code, pool_indices)
if __name__ == '__main__':
    # Train the autoencoder to reconstruct image pairs from './samples',
    # logging the per-batch MSE loss to loss_log.csv and saving the trained
    # weights to VGGAE.pt.
    #
    # Fall back to CPU when CUDA is unavailable instead of crashing on a
    # hard-coded 'cuda' device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoEncoderVGG().to(device)
    criterion = nn.MSELoss()
    dataset = MyAEDataset(path='./samples')
    dataloader = DataLoader(dataset, batch_size=1)
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.0001, momentum=0.8)
    model.train()
    loss_log = []
    for epoch in range(1000):
        for img1, img2 in dataloader:
            # DataLoader already yields detached tensors; just cast and move.
            x = img1.to(torch.float).to(device)
            y = img2.to(torch.float).to(device)
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_log.append(loss.item())
        if epoch % 2 == 0:
            # Last batch's loss of this epoch, as a cheap progress signal.
            print(loss.item())
    pd.DataFrame(loss_log).to_csv('loss_log.csv')
    torch.save(model.state_dict(), 'VGGAE.pt')
# Autoencoder based on VGG-BN.
# (Scraped article metadata: latest recommended article published 2024-01-18 09:52:19.)
# This file builds an autoencoder combining VGG16 with PyTorch for image encoding
# and decoding; it is trained on MyAEDataset with MSELoss as the optimization
# objective, and the loss values are recorded during training.
# (Abstract auto-generated by CSDN.)