Autoencoder (AE) for Image De-mosaicking in PyTorch

Early last year I wrote an AE implementation, but it was long-winded and not very mature. I came across it again today, so I am rewriting it from scratch.

I. Code

1. Overview of the code files

2. Main programs

(I). Training phase

(1).dataset.py

import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class GetData(Dataset):
    def __init__(self, path0, path1):  # build the lists of file names
        super(GetData, self).__init__()
        self.path0 = path0  # clean images
        self.path1 = path1  # mosaicked images
        # sort both lists so that clean/mosaic images are paired by file name
        self.name0_list = sorted(os.listdir(self.path0))
        self.name1_list = sorted(os.listdir(self.path1))
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.name0_list)

    def __getitem__(self, index):  # load one image pair by name; index is the sample index
        name0 = self.name0_list[index]
        name1 = self.name1_list[index]
        img0 = Image.open(os.path.join(self.path0, name0))
        img1 = Image.open(os.path.join(self.path1, name1))
        imgdata0 = self.img2data(img0)
        imgdata1 = self.img2data(img1)

        return imgdata0, imgdata1
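
The dataset above expects two folders holding identically named files: the clean 64x64 images and their mosaicked counterparts (the "64_dama" folder). The post does not show how that second folder is produced, so here is a minimal sketch of one way to generate it, assuming the mosaic is a simple pixelation of the whole image; the function name and the block size of 8 are my own choices.

import os
from PIL import Image

def make_mosaic_dataset(src_dir, dst_dir, block=8):
    """Pixelate every image in src_dir and save it under the same name in dst_dir."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, name)).convert('RGB')
        w, h = img.size
        # downscale by the block size, then upscale back with nearest-neighbour to create blocks
        small = img.resize((max(1, w // block), max(1, h // block)), Image.NEAREST)
        small.resize((w, h), Image.NEAREST).save(os.path.join(dst_dir, name))

# make_mosaic_dataset(r'C:\Users\87419\Desktop\data\64', r'C:\Users\87419\Desktop\data\64_dama')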

(2).net.py

import torch
import torch.nn as nn

# Convolution block: residual body followed by stride-2 downsampling
class ResConv2d(torch.nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResConv2d, self).__init__()

        self.sub_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0),  # 1x1 convolution
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.Conv2d(in_channels, in_channels, 3, 1, 1),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),
        )

        self.down_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.PReLU()
        )

    def forward(self, x):
        y = self.sub_net(x)
        return self.down_net(x + y)  # residual connection, then downsample

# Transposed-convolution block: residual body followed by stride-2 upsampling
class ResConvTranspose2d(torch.nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResConvTranspose2d, self).__init__()

        self.sub_net = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.ConvTranspose2d(in_channels, in_channels, 3, 1, 1),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.ConvTranspose2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),
        )

        self.up_net = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.PReLU(),
        )

    def forward(self, x):
        y = self.sub_net(x)
        return self.up_net(x + y)

# Encoder: 3x64x64 image -> 20-channel 1x1 code
class EncoderNet(torch.nn.Module):
    def __init__(self):
        super(EncoderNet, self).__init__()

        self.sub_net = torch.nn.Sequential(
            ResConv2d(3, 64),      # 64 -> 32
            ResConv2d(64, 128),    # 32 -> 16
            ResConv2d(128, 256),   # 16 -> 8
            ResConv2d(256, 512),   # 8 -> 4
            ResConv2d(512, 1024),  # 4 -> 2
            ResConv2d(1024, 20)    # 2 -> 1
        )

    def forward(self, x):
        return self.sub_net(x)

# Decoder: 20-channel 1x1 code -> 3x64x64 image
class DecoderNet(torch.nn.Module):
    def __init__(self):
        super(DecoderNet, self).__init__()

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(20, 1024, 4, 1, 0),  # 1 -> 4
            ResConvTranspose2d(1024, 512),   # 4 -> 8
            ResConvTranspose2d(512, 256),    # 8 -> 16
            ResConvTranspose2d(256, 128),    # 16 -> 32
            torch.nn.ConvTranspose2d(128, 3, 4, 2, 1)     # 32 -> 64
        )

    def forward(self, x):
        return self.decoder(x)
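
With a 64x64 input, each ResConv2d block halves the spatial size (64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1), so the encoder ends at a 20x1x1 code, and the decoder mirrors that path back to 3x64x64. A quick shape check (my addition, not in the original post) confirms this:

if __name__ == '__main__':
    x = torch.randn(2, 3, 64, 64)   # dummy batch of two 64x64 RGB images
    code = EncoderNet()(x)          # expected: torch.Size([2, 20, 1, 1])
    out = DecoderNet()(code)        # expected: torch.Size([2, 3, 64, 64])
    print(code.shape, out.shape)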

(3).train.py

import torch
import net
import dataset
import torch.nn as nn
import os
import shutil
from torch.utils.data import DataLoader
from torchvision.utils import save_image

loss_f = nn.MSELoss()
class MainNet(nn.Module): 
    def __init__(self):
        super(MainNet,self).__init__()

        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self,x1):
        y = self.encoder(x1)
        y_ = self.decoder(y)
        return y_

    def AELoss(self, y_, x0):
        return loss_f(y_, x0)

# Training
class Trainer(nn.Module):
    def __init__(self):
        super(Trainer, self).__init__()

        self.main_net = MainNet()
        self.main_net.cuda()

        os.makedirs('./param0', exist_ok=True)   # checkpoint directory
        os.makedirs('./result0', exist_ok=True)  # sample-image directory

        # There is a single reconstruction loss, so one optimizer updates both encoder and decoder
        ae_parameters = []
        ae_parameters.extend(self.main_net.encoder.parameters())
        ae_parameters.extend(self.main_net.decoder.parameters())
        self.opt_ae = torch.optim.Adam(ae_parameters, lr=1e-3)

    def train(self):
        # resume from the latest checkpoint if one exists
        if os.path.exists('./param0/encoder.pkl'):
            self.main_net.encoder.load_state_dict(torch.load('./param0/encoder.pkl'))
        if os.path.exists('./param0/decoder.pkl'):
            self.main_net.decoder.load_state_dict(torch.load('./param0/decoder.pkl'))

        self.dataloader = DataLoader(dataset.GetData(path0=r'C:\Users\87419\Desktop\data\64',
                         path1=r'C:\Users\87419\Desktop\data\64_dama'), batch_size=128, shuffle=True)

        for epoch in range(10000):
            count = 0

            # Each epoch walks over the whole dataset; every count step processes one batch of batch_size images.
            # len(dataloader) = total number of images / batch_size, i.e. the number of count steps per epoch.
            for img0data, img1data in self.dataloader:

                img0data = img0data.cuda()  # move the inputs to the GPU; everything downstream then runs on CUDA
                img1data = img1data.cuda()

                count += 1

                self.main_net.train()  # training mode
                y_ = self.main_net(img1data)                 # reconstruct the clean image from the mosaicked input
                aeloss = self.main_net.AELoss(y_, img0data)  # AE reconstruction loss
                self.opt_ae.zero_grad()
                aeloss.backward()
                self.opt_ae.step()

                if count % 25 == 0:
                    self.main_net.eval()  # evaluation mode while saving checkpoints and sample images
                    # write to a temporary file first, then copy it over the real checkpoint,
                    # so an interrupted save cannot corrupt the existing checkpoint
                    torch.save(self.main_net.encoder.state_dict(), './param0/encoder_tmp.pkl')
                    shutil.copyfile('./param0/encoder_tmp.pkl', './param0/encoder.pkl')
                    torch.save(self.main_net.decoder.state_dict(), './param0/decoder_tmp.pkl')
                    shutil.copyfile('./param0/decoder_tmp.pkl', './param0/decoder.pkl')

                    save_image(img0data[:1], './result0/{}_{}_0.jpg'.format(epoch, count))  # clean image
                    save_image(img1data[:1], './result0/{}_{}_1.jpg'.format(epoch, count))  # mosaicked image
                    save_image(y_[:1], './result0/{}_{}_1_0.jpg'.format(epoch, count))      # network reconstruction

                    print('epoch:', epoch, '|', 'count:', count, '|', 'aeloss:', aeloss.item())

if __name__ == '__main__':
    Trainer().train()

(II). Testing phase

(1).dataset_test.py

import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class GetData(Dataset):
    def __init__(self, path0):  # build the list of file names
        super(GetData, self).__init__()
        self.path0 = path0
        self.name0_list = sorted(os.listdir(self.path0))
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.name0_list)

    def __getitem__(self, index):  # load one image by name; index is the sample index
        name0 = self.name0_list[index]
        img0 = Image.open(os.path.join(self.path0, name0))
        imgdata0 = self.img2data(img0)

        return imgdata0

(2).test.py

import torch
import net
import dataset_test
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torchvision.utils import save_image

class MainNet(nn.Module):
    def __init__(self):
        super(MainNet,self).__init__()
        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self,x1):
        y = self.encoder(x1)
        y_ = self.decoder(y)
        return y_

class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()

        self.main_net = MainNet()
        self.main_net.cuda()

    def test(self):
        if os.path.exists('./param0/encoder.pkl'):
            self.main_net.encoder.load_state_dict(torch.load('./param0/encoder.pkl'))
        if os.path.exists('./param0/decoder.pkl'):
            self.main_net.decoder.load_state_dict(torch.load('./param0/decoder.pkl'))

        self.dataloader = DataLoader(dataset_test.GetData(path0=r'C:\Users\87419\Desktop\data\test'))
        count = 0
        self.main_net.eval()  # evaluation mode
        with torch.no_grad():  # no gradients are needed at test time
            for img0data in self.dataloader:
                img0data = img0data.cuda()
                encoded = self.main_net.encoder(img0data)
                decoded = self.main_net.decoder(encoded)
                count += 1
                save_image(decoded, r'C:\Users\87419\Desktop\data\AE_test_result/{}.jpg'.format(count))

if __name__ == '__main__':
    Test().test()
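
If you just want to run the trained weights on a single image without a DataLoader, a minimal sketch looks like this (the input file name is a placeholder of my own; it assumes a 64x64 RGB image like the training data):

import torch
import net
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

encoder = net.EncoderNet().cuda().eval()
decoder = net.DecoderNet().cuda().eval()
encoder.load_state_dict(torch.load('./param0/encoder.pkl'))
decoder.load_state_dict(torch.load('./param0/decoder.pkl'))

img = Image.open('mosaic_input.jpg').convert('RGB')  # placeholder file name
x = transforms.ToTensor()(img).unsqueeze(0).cuda()   # shape (1, 3, 64, 64)
with torch.no_grad():
    y = decoder(encoder(x))
save_image(y, 'restored.jpg')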

The test results are shown below (the model was not trained carefully, this is just to give a rough idea):
