1、代码
class Discriminator(nn.Module):
def __init__(self, input_dim=256, hidden_dim=256):
super(Discriminator, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dis1 = nn.Linear(input_dim, hidden_dim)
self.bn = nn.BatchNorm1d(hidden_dim)
self.dis2 = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = F.relu(self.dis1(x))
x = self.dis2(self.bn(x))
x = torch.sigmoid(x)
return x
2、概念和代码解读
(1)初始化方法
对应的代码
def __init__(self, input_dim=256, hidden_dim=256):
super(Discriminator, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dis1 = nn.Linear(input_dim, hidden_dim)
self.bn = nn.BatchNorm1d(hidden_dim)
self.dis2 = nn.Linear(hidden_dim, 1)
1)类的初始化方法用于构造函数。它接受了两个参数input_dim和hidden_dim。
input_dim指的是输入特征的维度,hidden_dim指的是隐藏层的维度。
2)super(Discriminator, self).__init__()调用父类 nn.Module
的构造函数,初始化父类中的各个成员变量。这是继承类构造函数的一种标准写法。
3)self.input_dim = input_dim
self.hidden_dim = hidden_dim
将传入的 input_dim 和 hidden_dim 参数赋值给类的实例变量
self 是一个对象的实例引用,用于在类的方法中访问该实例的属性和方法。
4)self.dis1 = nn.Linear(input_dim, hidden_dim)
定义一个线性层 dis1,它将输入的维度从 input_dim 映射到 hidden_dim
假设输入是图像的像素值或者某种高维特征,通过隐藏层可以将这些输入特征转换为更抽象、更具判别力的特征。例如,从低级的像素值到高级的模式识别特征。
设置隐藏层 dis1 主要是为了增强模型的表现力,使其能够学习和提取更高级的特征。这有助于提高分类任务的准确性和鲁棒性。在神经网络中,隐藏层是从输入到输出进行复杂特征变换和模式识别的关键环节。通过合理的隐藏层设计,模型能够有效地处理和识别复杂的输入模式。
5)self.bn = nn.BatchNorm1d(hidden_dim)
定义一个批标准化层 bn,用于对第一层的输出进行批标准化。批标准化层有助于加快模型训练速度并提高稳定性。
6)self.dis2 = nn.Linear(hidden_dim, 1)
定义第二个线性层 dis2,它将隐藏层的输出映射到一个单一的标量值(输出维度为 1),用于二元分类任务。
此时记录的输出维度为1,输出单元表示某一个类别的概率(通常为正类),并经过sigmoid函数控制输出在0-1之间。如果输出值接近于1,则认为输入为正类;如果输出值接近于0,则认为输入为负类。这种情况下可以使用BCE损失函数来计算损失。
若记录的输出维度为2,则输出单元为一个二维向量,并经过softmax函数映射。第一个元素表示负类的概率,第二个元素表示正类的概率。这种情况下可以使用交叉熵损失函数来计算损失。
输出为1实现简单,直接输出一个概率值,但只能用于二分类任务,需使用sigmoid函数+BCE损失函数。输出为2实现稍微复杂,输出为二维,可以扩散到多分类任务,不仅限于二分类任务,需使用softmax+crossentropyloss来计算损失。
对于二分类任务,两种都可以。如果只考虑二分类任务,输出维度为1更加简单。如果考虑整体一致性(如分类均使用相同的损失函数),可以考虑输出为2。
(2)前向传播方法
对应的代码
def forward(self, x):
x = F.relu(self.dis1(x))
x = self.dis2(self.bn(x))
x = torch.sigmoid(x)
return x
1)线性变换和ReLU激活
self.dis1(x): dis1 是一个线性层(全连接层),其作用是对输入 x 进行线性变换,通常表示为 ( Wx + b ),其中 ( W ) 是权重矩阵,( b ) 是偏置向量。
F.relu(...): ReLU(Rectified Linear Unit)激活函数,其作用是将所有负值变为0,保留正值不变。ReLU 引入了非线性,使模型能够学习更复杂的模式。
2)批标准化和线性变换
self.bn(x): bn 是批标准化层。批标准化通过标准化每一批数据的输出,使其均值为0,方差为1。这有助于加速训练,并使模型更稳定。
self.dis2(...): dis2 是另一个线性层,将标准化后的特征进一步映射到最终的输出维度。在你的例子中,dis2 的输出维度为1。
3)Sigmoid激活
torch.sigmoid(x): Sigmoid 激活函数将线性层的输出压缩到0到1之间,转化为概率值。这对于判别器来说非常重要,因为我们需要一个概率值来判断输入样本属于哪一类。