import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize
img_size = 64
z_dim = 100
batch_size = 128
lr = 2e-4
img_channel = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_d = 64
feature_g = 64
epoch1 = 5   # epochs for the optional Discriminator pretraining below
epoch2 = 30  # epochs of adversarial training
# data preparation
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(img_channel)], [0.5 for _ in range(img_channel)]),
])
train_data = datasets.MNIST('./data', train=True, transform=transform, download=True)
dataloader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
'''img,_ = next(iter(dataloader))
print(img.shape)'''
#model
D = Discriminator(img_channel,feature_d).to(device)
G = Generator(z_dim,feature_g,img_channel).to(device)
initialize(D)
initialize(G)
#optimizer
optim_d = torch.optim.Adam(D.parameters(),lr = lr,betas=(0.5,0.999))
optim_g = torch.optim.Adam(G.parameters(),lr = lr,betas=(0.5,0.999))
loss_fn = nn.BCELoss()
'''
# optional: pretrain the Discriminator alone for a few epochs
for epoch in range(epoch1):
    count = len(dataloader)
    loss_epoch_d = 0.0
    for step, (real, _) in enumerate(dataloader):
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        real = real.to(device)
        real_d = D(real).view(-1)
        loss_real_d = loss_fn(real_d, torch.ones_like(real_d))
        fake = G(noise)
        fake_d = D(fake.detach()).view(-1)
        loss_fake_d = loss_fn(fake_d, torch.zeros_like(fake_d))
        loss_d = (loss_fake_d + loss_real_d) / 2
        optim_d.zero_grad()
        loss_d.backward()
        optim_d.step()
        loss_epoch_d += loss_d.item()
    loss_epoch_d /= count
    print('Epoch:', epoch)
    print('loss is {}'.format(loss_epoch_d))
'''
writer = SummaryWriter('logs')
for epoch in range(epoch2):
    count = len(dataloader)
    loss_epoch_d = 0.0
    loss_epoch_g = 0.0
    for step, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # train Discriminator: push D(real) towards 1 and D(G(z)) towards 0
        real_d = D(real).view(-1)
        loss_dr = loss_fn(real_d, torch.ones_like(real_d))
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake = G(noise)
        fake_d = D(fake.detach()).view(-1)
        loss_df = loss_fn(fake_d, torch.zeros_like(fake_d))
        loss_d = (loss_dr + loss_df) / 2
        optim_d.zero_grad()
        loss_d.backward()
        optim_d.step()
        # train Generator: push D(G(z)) towards 1
        fake_d2 = D(fake).view(-1)
        loss_g = loss_fn(fake_d2, torch.ones_like(fake_d2))
        optim_g.zero_grad()
        loss_g.backward()
        optim_g.step()
        loss_epoch_d += loss_d.item()
        loss_epoch_g += loss_g.item()
    loss_epoch_d /= count
    loss_epoch_g /= count
    writer.add_scalar('D_EPOCH_LOSS', loss_epoch_d, epoch)
    writer.add_scalar('G_EPOCH_LOSS', loss_epoch_g, epoch)
writer.close()
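# Optional post-training step (not in the original script): save the trained
# weights and a grid of generated samples. A minimal sketch; the file names and
# the 64-sample grid size are assumptions.
import torchvision.utils as vutils

torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')
G.eval()  # use BatchNorm running statistics when sampling
with torch.no_grad():
    sample_noise = torch.randn(64, z_dim, 1, 1).to(device)
    samples = G(sample_noise)
    # Generator output is in [-1, 1] because of Tanh; normalize=True rescales it to [0, 1]
    vutils.save_image(samples, 'samples.png', nrow=8, normalize=True)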
# model.py
import torch
import torch.nn as nn
class Discriminator(nn.Module):
    def __init__(self, img_channel, feature_d):
        super(Discriminator, self).__init__()
        # input: (img_channel, 64, 64)
        self.disc = nn.Sequential(
            nn.Conv2d(img_channel, feature_d, kernel_size=4, stride=2, padding=1),  # 32*32
            nn.LeakyReLU(0.2),
            self.block(feature_d, feature_d * 2, 4, 2, 1),      # 16*16
            self.block(feature_d * 2, feature_d * 4, 4, 2, 1),  # 8*8
            self.block(feature_d * 4, feature_d * 8, 4, 2, 1),  # 4*4
            nn.Conv2d(feature_d * 8, 1, kernel_size=4, stride=1, padding=0),  # 1*1
            nn.Sigmoid()
        )

    def block(self, inc, outc, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(outc),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.disc(x)
class Generator(nn.Module):
    def __init__(self, z_dim, feature_g, img_c):
        super(Generator, self).__init__()
        # input: noise (N, z_dim, 1, 1)
        self.gen = nn.Sequential(
            self.block(z_dim, feature_g * 16, 4, 2, 0),          # (N, 4, 4)
            self.block(feature_g * 16, feature_g * 8, 4, 2, 1),  # (N, 8, 8)
            self.block(feature_g * 8, feature_g * 4, 4, 2, 1),   # (N, 16, 16)
            self.block(feature_g * 4, feature_g * 2, 4, 2, 1),   # (N, 32, 32)
            nn.ConvTranspose2d(feature_g * 2, img_c, 4, 2, 1),   # (N, 64, 64)
            nn.Tanh()
        )

    def block(self, inc, outc, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(inc, outc, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.gen(x)
def initialize(model):
    # DCGAN paper: all weights drawn from a zero-centered Normal with std 0.02
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
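# A common variant (e.g. the official PyTorch DCGAN tutorial) centres BatchNorm
# weights around 1.0 and zeroes the bias instead. Sketch only; the name
# initialize_v2 is made up here and nothing above calls it.
def initialize_v2(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)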
'''def test():
    N, img_channel, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, img_channel, H, W))
    D = Discriminator(img_channel, 8)
    initialize(D)
    x = D(x)
    assert x.shape == (N, 1, 1, 1)
    z = torch.randn((N, 100, 1, 1))
    G = Generator(z_dim, 8, img_c=3)
    initialize(G)
    z = G(z)
    assert z.shape == (N, 3, 64, 64)
    print('success')
test()
'''