class drop_model(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10,10),
self.drop = nn.Dropout(0.5)
self.fc2 = nn.Linear(10,10)
def forward(self, data):
res = self.fc1(data)
res = self.drop(res)
return self.fc2(res)
model = drop_model()
input_data = torch.ones(2,10)
model(input_data)

根据错误也可以看出问题在哪里,只是稍微花费了点时间,直白点说:就是因为一个逗号引发的错误。
self.fc1 = nn.Linear(10,10),
把这个后面的逗号去掉也就没事了,真的是好久没写代码了嘛。。。。。
顺带一提,上面的验证小代码的目的是为了说明dropout的层间具体操作,因为好奇一点:
dropout在对于每个batch中的样本都是采用同样的去连接操作呢?还是每个样本都不一样?
去掉逗号之后的输出如下:

这也就验证了,dropout是对每个样本都是不一样的,即使这些样本都在一个batch内。
8242

被折叠的 条评论
为什么被折叠?



