DCGAN (PyTorch)

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator,Generator,initilize

# Hyper-parameters for the DCGAN training run.
img_size = 64        # images are resized to img_size x img_size
z_dim = 100          # number of channels of the latent noise tensor
batch_size = 128
lr = 2e-4            # learning rate shared by both Adam optimizers
img_channel = 1      # MNIST images are single-channel
# NOTE: is_available must be *called*; the bare function object is always
# truthy, which would select 'cuda' even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_d = 64       # base channel width of the discriminator
feature_g = 64       # base channel width of the generator
epoch1 = 5           # epochs for the (disabled) discriminator pre-training
epoch2 = 30          # epochs for the adversarial training loop

# Data preparation: MNIST training images resized to img_size, converted to
# tensors and normalised per channel with mean 0.5 / std 0.5, i.e. mapped
# from [0, 1] to [-1, 1].
# The pipeline is bound to its own name so that the imported `transforms`
# module is not overwritten by a Compose instance.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * img_channel, [0.5] * img_channel),
])

train_data = datasets.MNIST('./data', train=True, transform=transform, download=True)
dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Quick sanity check of the batch shape (uncomment to use):
# img, _ = next(iter(dataloader))
# print(img.shape)

# Networks: build both on the target device first, then run the shared
# weight initialiser over each of them.
D = Discriminator(img_channel, feature_d).to(device)
G = Generator(z_dim, feature_g, img_channel).to(device)
for _net in (D, G):
    initilize(_net)

# Optimisers and loss: both networks use Adam with identical settings, and
# binary cross-entropy is applied to the discriminator's sigmoid outputs.
_adam_settings = {'lr': lr, 'betas': (0.5, 0.999)}
optim_d = torch.optim.Adam(D.parameters(), **_adam_settings)
optim_g = torch.optim.Adam(G.parameters(), **_adam_settings)
loss_fn = nn.BCELoss()


'''
#pretrain D
for epoch in range(epoch1):
    count = len(dataloader)
    count1 = len(train_data)
    for step,(real,_) in enumerate(dataloader):
        noise = torch.randn(batch_size,z_dim,1,1).to(device)
        real = real.to(device)
        real_d = D(real).view(-1)
        loss_real_d = loss_fn(real_d,torch.ones_like(real_d))
        fake = G(noise)
        fake_d = D(fake.detach()).view(-1)
        loss_fake_d = loss_fn(fake_d,torch.zeros_like(fake_d))
        loss_d = (loss_fake_d+loss_real_d)/2
        optim_d.zero_grad()

        with torch.no_grad():
            loss_d += loss_d
    with torch.no_grad():
        loss_epoch_d = loss_d/count
    print('Epoch:',epoch)
    print('loss is {}'.format(loss_epoch_d))
'''


# Adversarial training: for every batch, one discriminator update followed by
# one generator update; the mean loss of each network over the epoch is
# written to TensorBoard.
writer = SummaryWriter('logs')
try:
    count = len(dataloader)  # number of batches per epoch
    for epoch in range(epoch2):
        # Plain Python floats, so no autograd graph is kept alive across steps.
        running_loss_d = 0.0
        running_loss_g = 0.0
        for step, (real, _) in enumerate(dataloader):
            real = real.to(device)

            # Noise is created directly on the model's device and sized to the
            # actual batch (the last batch can be smaller than batch_size).
            noise = torch.randn(real.size(0), z_dim, 1, 1, device=device)
            fake = G(noise)

            # --- Discriminator step: real -> 1, fake -> 0 ---
            real_d = D(real).view(-1)
            loss_dr = loss_fn(real_d, torch.ones_like(real_d))
            # detach() keeps this update from back-propagating into G.
            fake_d = D(fake.detach()).view(-1)
            loss_df = loss_fn(fake_d, torch.zeros_like(fake_d))
            loss_d = (loss_dr + loss_df) / 2
            optim_d.zero_grad()
            loss_d.backward()
            optim_d.step()

            # --- Generator step: make the updated D label fakes as real ---
            fake_d2 = D(fake).view(-1)
            loss_g = loss_fn(fake_d2, torch.ones_like(fake_d2))
            optim_g.zero_grad()
            loss_g.backward()
            optim_g.step()

            # Accumulate into separate counters; adding a loss to itself
            # would only double the value of the last batch.
            running_loss_d += loss_d.item()
            running_loss_g += loss_g.item()

        loss_epoch_d = running_loss_d / count
        loss_epoch_g = running_loss_g / count

        # add_scalar(tag, scalar_value, global_step): value first, step second.
        writer.add_scalar('D_EPOCH_LOSS', loss_epoch_d, epoch)
        writer.add_scalar('G_EPOCH_LOSS', loss_epoch_g, epoch)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """DCGAN discriminator: strided convolutions ending in a sigmoid score.

    For an input of shape (N, img_channel, 64, 64) the output has shape
    (N, 1, 1, 1), each value lying in [0, 1].

    Args:
        img_channel: number of channels of the input images.
        feature_d: base channel width; deeper layers use 2x, 4x and 8x of it.
    """

    def __init__(self, img_channel, feature_d):
        super(Discriminator, self).__init__()
        # Spatial sizes in the comments assume a 64x64 input.
        self.disc = nn.Sequential(
            nn.Conv2d(img_channel, feature_d, kernel_size=4, stride=2, padding=1),  # 32x32
            # Without this activation the first two convolutions would be
            # stacked with no non-linearity in between.
            nn.LeakyReLU(0.2),
            self.block(feature_d, feature_d * 2, 4, 2, 1),      # 16x16
            self.block(feature_d * 2, feature_d * 4, 4, 2, 1),  # 8x8
            self.block(feature_d * 4, feature_d * 8, 4, 2, 1),  # 4x4
            nn.Conv2d(feature_d * 8, 1, kernel_size=4, stride=1, padding=0),  # 1x1
            nn.Sigmoid(),
        )

    def block(self, inc, outc, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The convolution has no bias because the following BatchNorm layer
        provides its own learnable shift.
        """
        return nn.Sequential(
            nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(outc),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Run the image batch ``x`` through the network and return the scores."""
        return self.disc(x)


class Generator(nn.Module):
    """DCGAN generator: transposed convolutions mapping noise to an image.

    For a noise input of shape (N, z_dim, 1, 1) the output has shape
    (N, img_c, 64, 64) with values in [-1, 1] (final Tanh).

    Args:
        z_dim: number of channels of the input noise tensor.
        feature_d: base channel width of the generator; the layers use
            16x, 8x, 4x and 2x of it.
        img_c: number of channels of the generated image.
    """

    def __init__(self, z_dim, feature_d, img_c):
        super(Generator, self).__init__()
        # Input: noise of shape (N, z_dim, 1, 1); comments give spatial size.
        self.gen = nn.Sequential(
            self.block(z_dim, feature_d * 16, 4, 2, 0),          # 4x4
            self.block(feature_d * 16, feature_d * 8, 4, 2, 1),  # 8x8
            self.block(feature_d * 8, feature_d * 4, 4, 2, 1),   # 16x16
            self.block(feature_d * 4, feature_d * 2, 4, 2, 1),   # 32x32
            nn.ConvTranspose2d(feature_d * 2, img_c, 4, 2, 1),   # 64x64
            nn.Tanh(),
        )

    def block(self, inc, outc, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage.

        The transposed convolution has no bias: the BatchNorm that follows
        removes any constant offset and supplies its own learnable shift
        (same convention as ``Discriminator.block``).
        """
        return nn.Sequential(
            nn.ConvTranspose2d(inc, outc, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map the noise batch ``x`` to a batch of generated images."""
        return self.gen(x)


def initilize(modle):
    """Apply DCGAN-style weight initialisation to ``modle`` in place.

    Convolution and transposed-convolution weights are drawn from
    N(0, 0.02). BatchNorm scale factors are drawn from N(1, 0.02) and their
    biases set to 0, so normalised activations start near identity scale
    instead of being multiplied by values close to zero.

    Args:
        modle: an ``nn.Module`` whose sub-modules are initialised.
    """
    for m in modle.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


'''def test():
    N,img_channel,H,W = 8,3,64,64
    z_dim = 100
    x = torch.randn((N,img_channel,H,W))
    D = Discriminator(img_channel,8)
    initilize(D)
    x = D(x)
    assert x.shape == (N,1,1,1)
    z = torch.randn((N,100,1,1))
    G = Generator(z_dim,8,img_c= 3)
    initilize(G)
    z = G(z)
    assert z.shape == (N,3,64,64)
    print('sucess')

test()
'''

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值