构建判别器模块
import torch.nn as nn
import numpy as np
# Shape of a single input sample, excluding the batch dimension.
# NOTE(review): presumably (channels, height, width) — confirm against the data loader.
img_shape = (1, 28, 28)
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened inputs.

    The network is a three-layer MLP (in_features -> 512 -> 256 -> 1) with
    LeakyReLU(0.2) activations and a final sigmoid, so each output value lies
    between 0 and 1.

    Args:
        input_shape: Shape of one input sample without the batch dimension.
            When ``None`` (the default), the module-level ``img_shape`` is
            used, so ``Discriminator()`` keeps working unchanged.
    """

    def __init__(self, input_shape=None):
        super().__init__()
        if input_shape is None:
            # Resolved at construction time so a bare ``Discriminator()``
            # still follows the module-level setting.
            input_shape = img_shape
        # Number of scalar values in one sample once it is flattened.
        in_features = int(np.prod(input_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of inputs.

        Args:
            img: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to the
                ``in_features`` of the first linear layer.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the sigmoid outputs.
            NOTE(review): presumably read as "probability the input is real";
            that interpretation is set by the training loop, not shown here.
        """
        # reshape() gives the same result as view() for contiguous tensors
        # but copies instead of raising when the input is non-contiguous
        # (e.g. after transpose/permute).
        img_flat = img.reshape(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
# Create a Discriminator instance using the constructor's defaults.
discriminator = Discriminator()
# Bare expression: in an interactive session or notebook cell this echoes the
# module's repr (the layer listing that follows); in a plain script it has no effect.
discriminator
Discriminator(
(model): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Linear(in_features=512, out_features=256, bias=True)
(3): LeakyReLU(negative_slope=0.2, inplace=True)
(4): Linear(in_features=256, out_features=1, bias=True)
(5): Sigmoid()
)
)