论文传送门:Generative Adversarial Nets
class Discriminator(nn.Module): # 定义判别器
def __init__(self, img_size=(28, 28)): # 初始化方法
super(Discriminator, self).__init__() # 继承初始化方法
self.img_size = img_size # 图片尺寸,默认单通道灰度图
self.linear1 = nn.Linear(self.img_size[0] * self.img_size[1], 512) # linear映射
self.linear2 = nn.Linear(512, 256) # linear映射
self.linear3 = nn.Linear(256, 1) # linear映射
self.leakyrelu = nn.LeakyReLU(0.2, inplace=True) # leakyrelu激活函数
self.sigmoid = nn.Sigmoid() # sigmoid激活函数,将输出压缩至(0,1)
def forward(self, x): # 前传函数
x = torch.flatten(x, 1) # 输入图片从三维压缩至一维特征向量,(n,1,28,28)-->(n,784)
x = self.linear1(x) # linear映射,(n,784)-->(n,512)