概念
生成网络
尺寸变大(1 → 图片尺寸)。
通道数变大,再逐渐变小(in → max → …… → 3)。
输出用 Tanh 激活。
输出层不用 BN。
判别网络
尺寸变小(图片尺寸 → 1)。
通道数逐渐变大,再变小(3 → …… → max → 1)。
输出用 Sigmoid 激活。
输入层不用 BN。
实验(生成卡通人脸)
数据集:96×96 的卡通人脸。(5 万)
网络结构:
- 判别器:卷积 + 标准化(BN)+ 激活(LeakyReLU)+ Sigmoid。
- 生成器:转置卷积 + 标准化(BN)+ 激活(ReLU)+ Tanh。
优化器:Adam(lr=0.0002, betas=(0.5, 0.999))。
损失函数:二进制交叉熵(BCELoss)。
输出:
- 判别网络:图片为真的概率。
- 生成网络:图片。
数据集
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.imgs = os.listdir(path)
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
img = Image.open(os.path.join(self.path, self.imgs[index]))
return self.transform(img)
网络
import torch
from torch import nn
# 判别器
class D_Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 64, 5, 3, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128)