CSDNL+CONV

import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

transform = transforms.Compose(
        [
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
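The Normalize step maps each channel from [0, 1] to [-1, 1], matching the Tanh output range of the generator defined below, and CenterCrop(64) fixes the spatial size the networks expect. A minimal sanity check of the pipeline, assuming some RGB image is available (the filename 'sample.jpg' is only a placeholder):

img = Image.open('sample.jpg').convert('RGB')   # placeholder path, not part of the dataset
x = transform(img)
print(x.shape)                          # torch.Size([3, 64, 64])
print(x.min().item(), x.max().item())   # values fall in [-1, 1]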



nz = 256    # dimension of the latent noise vector z
ngf = 64    # base number of generator feature maps
nc = 3      # number of image channels (RGB)

class Generator(nn.Module):
    """DCGAN generator: maps a latent vector z to a 3 x 64 x 64 image."""
    def __init__(self):
        super(Generator, self).__init__()
        # Generator architecture
        self.main = nn.Sequential(
            # input size: nz x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size: (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # output size: (nc) x 64 x 64
        )

    def forward(self, input):
        output = self.main(input)
        return output
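A quick shape check makes the architecture concrete: a batch of latent vectors of shape (N, nz, 1, 1) is upsampled through five transposed convolutions to (N, 3, 64, 64), with values in [-1, 1] from the final Tanh. A minimal sketch using the Generator defined above:

netG_check = Generator()
z = torch.randn(4, nz, 1, 1)    # batch of 4 latent vectors
fake = netG_check(z)
print(fake.shape)               # torch.Size([4, 3, 64, 64])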


ndf = 16    # base number of discriminator feature maps
class Discriminator(nn.Module):
    """DCGAN discriminator: scores a 3 x 64 x 64 image as real (1) or fake (0)."""

    def __init__(self):
        super(Discriminator, self).__init__()

        # Discriminator architecture
        self.main = nn.Sequential(
            # input size: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        # Note the output is flattened into a 1-D tensor, one score per image
        return output.view(-1, 1).squeeze(1)
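The discriminator mirrors the generator: five strided convolutions reduce a (N, 3, 64, 64) batch to a single sigmoid score per image, returned as a 1-D tensor of length N. For example:

netD_check = Discriminator()
scores = netD_check(torch.randn(4, nc, 64, 64))
print(scores.shape)             # torch.Size([4]), each value in (0, 1)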


class MyData(Dataset):
    """Dataset of avatar images read from a single directory."""
    def __init__(self, is_train):
        super(MyData, self).__init__()
        # is_train is currently unused; all images live under one directory
        self.root = '\\\\vdinas.ymtc.local\\perdata\\E908724\\Downloads\\dcgan_anime_avatars-master\\data\\'
        self.path = self.root
        # List the directory once instead of on every __getitem__
        self.imgs = os.listdir(self.path)

    def __getitem__(self, item):
        img = Image.open(os.path.join(self.path, self.imgs[item])).convert('RGB')
        img = transform(img)
        return img

    def __len__(self):
        return len(self.imgs)
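Wrapped in a DataLoader, the dataset yields batches of preprocessed image tensors directly usable by the networks above. A short sketch (it assumes the machine-specific root path above actually contains images):

dataset_check = MyData(is_train=True)
loader_check = DataLoader(dataset_check, batch_size=16, shuffle=True, drop_last=True)
batch = next(iter(loader_check))
print(batch.shape)              # torch.Size([16, 3, 64, 64])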

def to_img(x):
    '''
    Map generator output from [-1, 1] back to [0, 1] image values for saving.
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 3, 64, 64)
    return x
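Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes x_norm = (x - 0.5) / 0.5, so its inverse is x = 0.5 * (x_norm + 1), which is exactly what to_img applies before clamping. A quick round-trip check:

t = torch.rand(2, 3, 64, 64)                 # fake images already in [0, 1]
normed = (t - 0.5) / 0.5                     # what the Normalize transform does
print(torch.allclose(to_img(normed), t))     # True: to_img inverts the normalization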

if __name__ == '__main__':
    batch_size = 16
    epochs = 10000
    train_data = MyData(is_train=True)
    data_set = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    netD = Discriminator()
    netG = Generator()
    # The discriminator already ends in Sigmoid, so plain BCELoss is the matching
    # criterion (BCEWithLogitsLoss would apply a second sigmoid).
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002)
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002)

    # Target labels: 1 for real images, 0 for fake images
    label_true = torch.ones(batch_size)
    label_fake = torch.zeros(batch_size)

    # Fixed noise vectors, used only to save comparable sample grids during training
    fixed_noise = torch.randn(batch_size, nz, 1, 1)
    for epoch in range(epochs):
        for index, img in enumerate(data_set):
            # ---- Train the discriminator ----
            optimizerD.zero_grad()

            # Loss on real images
            output = netD(img)
            d_loss_real = criterion(output, label_true)
            d_loss_real.backward()

            # Loss on fake images (detach so gradients do not flow into the generator)
            noise = torch.randn(batch_size, nz, 1, 1)
            generated_images = netG(noise).detach()
            d_loss_fake = criterion(netD(generated_images), label_fake)
            d_loss_fake.backward()
            optimizerD.step()
            print(index, ' D ', (d_loss_real + d_loss_fake).item())

            # ---- Train the generator ----
            optimizerG.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1)
            generated_images = netG(noise)
            # The generator tries to make the discriminator output "real"
            g_loss_fake = criterion(netD(generated_images), label_true)
            g_loss_fake.backward()
            optimizerG.step()
            print(index, ' G ', g_loss_fake.item())

            # Periodically save a sample grid generated from the fixed noise
            if index % 50 == 0:
                with torch.no_grad():
                    pic = to_img(netG(fixed_noise).cpu())
                if not os.path.exists('./simple_autoencoder'):
                    os.mkdir('./simple_autoencoder')
                save_image(pic, './simple_autoencoder/image_{}_{}.png'.format(epoch, index))
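At the end of the main block, the trained generator can be kept for later sampling; a minimal sketch using torch.save (the checkpoint filename here is arbitrary):

    torch.save(netG.state_dict(), 'netG.pth')      # arbitrary checkpoint name

    # Later, to sample new avatars from the saved weights:
    # netG = Generator()
    # netG.load_state_dict(torch.load('netG.pth'))
    # samples = to_img(netG(torch.randn(16, nz, 1, 1)).detach())
    # save_image(samples, 'samples.png')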
