Image Inpainting: Implementing Context Encoders in PyTorch

GANs have gained real traction in recent years. This post reproduces the first paper (from 2016) to use a GAN for image inpainting, mainly as a learning exercise.
Reference paper: Context Encoders: Feature Learning by Inpainting
Reference GitHub project: https://github.com/BoyuanJiang/context_encoder_pytorch
Other references: https://blog.csdn.net/EstherWjj/article/details/120765518 etc.


The Context Encoders paper

As the first paper to apply GANs to image inpainting, it naturally has standout contributions as well as shortcomings. The shortcomings are plain to see: the restoration quality after training is limited, and the model can only predict regularly placed missing regions (the paper trains with the center region removed).
Contributions:
The paper was the first to use the encoder-decoder pattern for this task: the input image is encoded, and the resulting code is decoded back into the missing image content.
Treating the center region as the prediction target raises two concerns: whether the generated patch can pass as a real image, and whether it fits the global image. The paper couples two loss functions (BCELoss and MSELoss) so that the filled-in region both looks plausible and stays consistent with its context.
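To make the coupling concrete, here is a minimal runnable sketch of the joint generator objective; the tensors are dummies standing in for the discriminator scores and image patches, and 0.998 is the reconstruction weight that train.py uses below.

import torch
import torch.nn as nn

bce, mse = nn.BCELoss(), nn.MSELoss()
wtl2 = 0.998                                  # reconstruction weight, as in train.py

score = torch.sigmoid(torch.randn(8, 1))      # stand-in for D(G(x)), values in (0, 1)
fake = torch.randn(8, 3, 64, 64)              # stand-in for the generated center patch
real = torch.randn(8, 3, 64, 64)              # stand-in for the real center patch

adv = bce(score, torch.ones_like(score))      # "can it pass as a real image" term
rec = mse(fake, real)                         # "does it fit the global image" term
loss_G = (1 - wtl2) * adv + wtl2 * rec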

For the network architecture and other details, please read the original paper.


The code was written in VS Code (because VS Code supports some Jupyter features via #%% cells), with Python 3.9.

Dataset

Here is one way to obtain the data, purely for this learning exercise: put the Paris StreetView images under the local path \context_encoders\dataset\train\streetview.
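A quick check, just a sketch under the path above: torchvision's ImageFolder treats each subfolder of dataroot as one class, so the images must sit inside streetview/, one level below dataset/train.

import os

root = os.path.join('dataset', 'train', 'streetview')
imgs = [f for f in os.listdir(root) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f'{len(imgs)} images found under {root}')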

Code

After some simplification the code comes down to three main files: model.py, train.py, and test.py.

model.py

In the generator, a 128×128 masked image is downsampled by six convolutional layers to the 1×1 bottleneck described in the paper (nBottleneck), then upsampled by five transposed-convolution layers into a 64×64 generated center patch. The opt argument is documented in train.py; in model.py it mainly supplies the channel widths, and inplace=True in the activations updates tensors in place to save memory. A quick shape check follows the code below.

import torch
import torch.nn as nn

class _netG(nn.Module):
    def __init__(self,opt):
        super(_netG,self).__init__()
        self.main = nn.Sequential(
            # encoder: five stride-2 convs (128 -> 4), then a kernel-4 conv down to a 1x1 bottleneck
            nn.Conv2d(opt.nc,opt.nef,4,2,1,bias=False), #3,64
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.nef,opt.nef,4,2,1,bias=False),
            nn.BatchNorm2d(opt.nef),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.nef,opt.nef*2,4,2,1,bias=False),
            nn.BatchNorm2d(opt.nef*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.nef*2,opt.nef*4,4,2,1,bias=False),
            nn.BatchNorm2d(opt.nef*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.nef*4,opt.nef*8,4,2,1,bias=False),
            nn.BatchNorm2d(opt.nef*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.nef*8,opt.nBottleneck,4,bias=False),
            nn.BatchNorm2d(opt.nBottleneck),
            nn.LeakyReLU(0.2,inplace=True),
            # decoder: five transposed convs take the 1x1 bottleneck up to a 64x64 center patch
            nn.ConvTranspose2d(opt.nBottleneck,opt.ngf*8,4,1,0,bias=False),
            nn.BatchNorm2d(opt.ngf*8),
            nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ngf*4),
            nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ngf*2),
            nn.ReLU(True), 
            nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ngf),
            nn.ReLU(True),        
            nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False),
            nn.Tanh() 
        )
    def forward(self,input):
        output = self.main(input)
        return output

class _netlocalD(nn.Module):
    def __init__(self,opt):
        super(_netlocalD,self).__init__()
        self.main = nn.Sequential(
            # discriminator over the 64x64 center patch, ending in a single realism score
            nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.ndf,opt.ndf*2,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.ndf*2,opt.ndf*4,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.ndf*4,opt.ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
            nn.Sigmoid()
        )
    def forward(self,input):
        output = self.main(input)
        return output.view(-1,1)
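As a quick sanity check, something like the following sketch can be run at the bottom of model.py (the opt fields mirror the train.py defaults listed below):

import argparse

opt = argparse.Namespace(nc=3, nef=64, ngf=64, ndf=64, nBottleneck=4000)
netG, netD = _netG(opt), _netlocalD(opt)
x = torch.randn(2, 3, 128, 128)  # a batch of two masked 128x128 inputs
center = netG(x)                 # the generator outputs only the center patch
print(center.shape)              # torch.Size([2, 3, 64, 64])
print(netD(center).shape)        # torch.Size([2, 1]): one realism score per patch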

train.py

Main parameters

--dataset  which training dataset to use
--dataroot  path to the dataset (downloaded or already present)
--workers  number of worker processes for data preprocessing and loading
--batchSize  number of images fed to the model per batch
--imageSize  size to which the original images are resampled before entering the model
--nz  size of the latent z vector
--ngf  base number of feature maps in the generator
--ndf  base number of feature maps in the discriminator
--netG  path to a saved generator (to continue training)
--netD  path to a saved discriminator (to continue training)
--niter  number of training epochs
--lr  initial learning rate
--beta1  β₁ for the Adam optimizer
--nef  number of encoder filters in the first conv layer
--overlapPred  width in pixels of the band by which the predicted center region overlaps the surrounding real context; the mask is shrunk by this amount on each side, and the overlap band gets a higher weight in the L2 loss (see the sketch after this list)
--nBottleneck  dimensionality of the encoder bottleneck
--cuda  train on the GPU
--outf  folder where output images are saved
--manualSeed  seed for the random number generators
--wtl2  weight of the L2 reconstruction loss (0.998)
--wtlD  weight of the adversarial loss (0.001)
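To make overlapPred concrete, a small sketch with the defaults used below (imageSize=128, overlapPred=4): the generator predicts the center region [32, 96) in both dimensions, but only [36, 92) is blanked out, so a 4-pixel band of real context overlaps the predicted region.

imageSize, overlapPred = 128, 4
center = (imageSize // 4, imageSize // 4 + imageSize // 2)   # (32, 96): region the generator predicts
masked = (center[0] + overlapPred, center[1] - overlapPred)  # (36, 92): region actually blanked out
print(center, masked)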

Imports

#%%
import argparse
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.autograd import Variable
from torchvision import utils as vutils
from model import _netG, _netlocalD

Argument parsing

argparse is used here mainly to simplify parameter handling and to feed the layer-width parameters used in model.py.

#%%
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device." .format(device))
#%%
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',  default='streetview', help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--dataroot',  default='dataset/train', help='path to dataset')
parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the input image to network')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=0) # set to 0 on CPU; the original code uses 2
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--overlapPred',type=int,default=4,help='overlapping edges')

parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
# parameters consumed inside model.py
parser.add_argument('--nef',type=int,default=64,help='of encoder filters in first conv layer')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--nBottleneck', type=int, default=4000, help='the bottleneck dimensionality described in the paper')
opt = parser.parse_args(args=[])

Output folders

All results are written to these folders.

#%% output folders
# exist_ok avoids the original try/except, which skipped the remaining
# makedirs calls as soon as one folder already existed
os.makedirs("result/train/cropped", exist_ok=True)
os.makedirs("result/train/real", exist_ok=True)
os.makedirs("result/train/recon", exist_ok=True)
os.makedirs("model", exist_ok=True)

Fixing the random seed

#%% fix the random seeds so results are reproducible
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print(f'random seed: {opt.manualSeed}')
random.seed(opt.manualSeed)        # fix Python's RNG
torch.manual_seed(opt.manualSeed)  # fix PyTorch's RNG
if device.type == "cuda":
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True  # let cuDNN benchmark its convolution algorithms and pick the fastest

Data processing

Only the streetview images are used here; if you want more datasets, refer to the handling in the GitHub repo.

#%% data pipeline
# preprocessing for the files under streetview
transform = transforms.Compose([transforms.Resize(opt.imageSize),  # Scale was renamed Resize in torchvision
                                transforms.CenterCrop(opt.imageSize),  # crop out the center region
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # map to [-1, 1]
dataset = dset.ImageFolder(root=opt.dataroot,transform=transform)
assert dataset
dataloader = torch.utils.data.DataLoader(dataset,batch_size=opt.batchSize,shuffle=True,num_workers=opt.workers)
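A one-off check (just a sketch) that the pipeline produces what the model expects, i.e. batches of shape [batchSize, 3, 128, 128] with values in [-1, 1]:

imgs, _ = next(iter(dataloader))
print(imgs.shape, imgs.min().item(), imgs.max().item())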

Network initialization

#%% weight initialization: DCGAN-style init speeds up convergence and helps it settle near a good optimum
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:  # class names are e.g. BatchNorm2d, so match 'BatchNorm'
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)                 # fill_ is the in-place call; fill() does not exist
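For reference, an equivalent sketch using isinstance and torch.nn.init, which avoids the fragile class-name matching:

def weights_init_v2(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)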

Loading the networks

#%% load the networks
resume_epoch = 0  # epoch to resume from
netG = _netG(opt).to(device)
netG.apply(weights_init)  # applied recursively to every submodule

# if a trained checkpoint is given, load it directly
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG,map_location=lambda storage,location:storage)['state_dict'])
    resume_epoch = torch.load(opt.netG)['epoch']
# print(netG)

netD = _netlocalD(opt).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD,map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netD)['epoch']
print(netD)

Loss functions

The two losses here correspond to the two concerns raised earlier.

#%% 损失函数
criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()

Tensors and optimizers

#%% tensors used during training
input_real = torch.FloatTensor(opt.batchSize,3,opt.imageSize,opt.imageSize)
input_cropped = torch.FloatTensor(opt.batchSize,3,opt.imageSize,opt.imageSize)
label = torch.FloatTensor(opt.batchSize)
real_label = 1 
fake_label = 0
real_center = torch.FloatTensor(opt.batchSize,3,int(opt.imageSize/2), int(opt.imageSize/2))

# Variable wraps a tensor for autograd (backprop and parameter updates);
# since PyTorch 0.4 plain tensors support autograd themselves, so Variable
# is kept here only to stay close to the original code
input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)
real_center = Variable(real_center)

#%% optimizers (use opt.lr rather than a hard-coded rate, matching the declared default)
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.5, 0.999))

Training loop

#%% training loop
#opt.niter = 100
overlapL2Weight = 10
wtl2 = 0.998
for epoch in range(resume_epoch, opt.niter):
    for i,data in enumerate(dataloader,0):
        real_data, _ = data
        #take the center region of the image; the same indexing recurs below
        real_center = real_data[:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4 + opt.imageSize/2)]
        batch_size = real_data.size(0)
        input_cropped.resize_(real_data.size()).copy_(real_data)
        # fill the masked center (shrunk by overlapPred pixels on each side) of each
        # channel with the dataset mean color (117, 104, 123), rescaled to [-1, 1]
        input_cropped.data[:,0,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*117.0/255.0 - 1.0
        input_cropped.data[:,1,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*104.0/255.0 - 1.0
        input_cropped.data[:,2,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*123.0/255.0 - 1.0

        real_center = real_center.to(device)
        real_data = real_data.to(device)
        input_cropped = input_cropped.to(device)
        # train the discriminator on real data
        netD.zero_grad()
        label.resize_(batch_size).fill_(real_label)

        output = netD(real_center)
        output = output.squeeze(1).to(device)
    
        errD_real = criterion(output,label)
        errD_real.backward()
        D_x = output.data.mean()

        #train the discriminator on generated (fake) data
        fake = netG(input_cropped)
        label.data.fill_(fake_label)
        output = netD(fake.detach())
        output = output.squeeze(1)
        errD_fake = criterion(output,label)
        errD_fake.backward()  # the first concern: does the patch look real?
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()

        #update the generator
        netG.zero_grad()
        label.data.fill_(real_label)
        output = netD(fake)
        output = output.squeeze(1)
        errG_D = criterion(output,label)

        #the weighted MSE loss handles context consistency
        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(wtl2*overlapL2Weight)
        wtl2Matrix.data[:,:,int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred),int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred)] = wtl2

        errG_l2 = (fake-real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix
        errG_l2 = errG_l2.mean()
        errG = (1-wtl2) * errG_D + wtl2 * errG_l2
        errG.backward()  # the second concern: does the patch fit its context?

        D_G_z2 = output.data.mean()
        optimizerG.step()
        print(f'[{epoch}/{opt.niter}][{i}/{len(dataloader)}] Loss_D:{errD.data.item():>.4f} Loss_G:{errG_D.data.item():>.4f}/{errG_l2.data.item():>.4f} l_D(x): { D_x:>.4f} l_D(G(z)):{D_G_z1:>.4f}')

        # the rest just saves results
        if i % 100 == 0:
            vutils.save_image(real_data,
                    'result/train/real/real_samples_epoch_%03d.png' % (epoch))
            vutils.save_image(input_cropped.data,
                    'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
            recon_image = input_cropped.clone()
            recon_image.data[:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2)] = fake.data
            vutils.save_image(recon_image.data,
                    'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
    torch.save({'epoch':epoch+1,
            'state_dict':netG.state_dict()},
            'model/netG_streetview.pth' )
    torch.save({'epoch':epoch+1,
                'state_dict':netD.state_dict()},
                'model/netlocalD.pth' )

Training results

I trained for 100 epochs in total, which is still too few, but the effect is roughly visible.
Epoch 1:
[figure: samples after the first epoch]

Epoch 100:
[figure: samples after epoch 100]
That wraps up the training part.

test.py

I have not made many improvements here, so the code is a bit dated; use it as you see fit. Put the image you want to test in the dataset folder as 1.png (you can change this yourself). It runs fine; as for the quality of the results... well, it simply needs more training, otherwise the output is still rough.

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
from model import _netG

def load_image(filename, size=None, scale=None):
    img = Image.open(filename)
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
    return img

def save_image(filename, data):
    img = data.clone().add(1).div(2).mul(255).clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype("uint8")
    img = Image.fromarray(img)
    img.save(filename)

parser = argparse.ArgumentParser()
parser.add_argument('--dataset',  default='streetview', help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='model/netG_streetview.pth', help="path to the trained netG checkpoint (the load below hard-codes this path anyway)")
parser.add_argument('--manualSeed', type=int, help='manual seed')

parser.add_argument('--nBottleneck', type=int,default=4000,help='of dim for bottleneck of encoder')
parser.add_argument('--overlapPred',type=int,default=4,help='overlapping edges')
parser.add_argument('--nef',type=int,default=64,help='of encoder filters in first conv layer')
parser.add_argument('--wtl2',type=float,default=0.999,help='0 means do not use else use with this weight')
opt = parser.parse_args(args=[])

netG = _netG(opt)
# netG = TransformerNet()
netG.load_state_dict(torch.load('model/netG_streetview.pth',map_location=lambda storage, location: storage)['state_dict'])
# netG.requires_grad = False
netG.eval()

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

image = load_image('dataset/1.png', opt.imageSize)
image = transform(image)
image = image.repeat(1, 1, 1, 1)  # prepends a batch dimension: (3, H, W) -> (1, 3, H, W)

input_real = torch.FloatTensor(1, 3, opt.imageSize, opt.imageSize)
input_cropped = torch.FloatTensor(1, 3, opt.imageSize, opt.imageSize)
real_center = torch.FloatTensor(1, 3, int(opt.imageSize/2), int(opt.imageSize/2))

criterionMSE = nn.MSELoss()

input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
real_center = Variable(real_center)

input_real.data.resize_(image.size()).copy_(image)
input_cropped.data.resize_(image.size()).copy_(image)
real_center_cpu = image[:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2)]
real_center.data.resize_(real_center_cpu.size()).copy_(real_center_cpu)

input_cropped.data[:,0,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*117.0/255.0 - 1.0
input_cropped.data[:,1,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*104.0/255.0 - 1.0
input_cropped.data[:,2,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*123.0/255.0 - 1.0

fake = netG(input_cropped)
errG = criterionMSE(fake,real_center)

recon_image = input_cropped.clone()
recon_image.data[:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2)] = fake.data


save_image('val_real_samples.png',image[0])
save_image('val_cropped_samples.png',input_cropped.data[0])
save_image('val_recon_samples.png',recon_image.data[0])

print('%.4f' % errG.data.item())
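To put the printed MSE into perspective, it can be converted to PSNR; a minimal sketch, assuming the tensors stay in [-1, 1] (a peak-to-peak range of 2):

import math

mse = errG.data.item()
psnr = 10 * math.log10(2.0 ** 2 / mse) if mse > 0 else float('inf')
print(f'PSNR: {psnr:.2f} dB')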

Conclusion

That concludes the paper walkthrough and the reproduction code. Since the goal was mainly learning, corrections and suggestions about any of the steps or my understanding are very welcome.
