前言
看到这样一行代码
label = torch.zeros_like(pred_label)
label.scatter_(1, batch_data.label.cuda().unsqueeze(dim=1), 1)
解释
假设这里的pred_label的大小为(32,18),32是batch_size,18代表此分类任务有18个类别。
我们创建的label大小也为(32,18),全0。
再假设这里的batch_data.label长下面的样子,它代表batch中,第一个样本的类别是0,第二个样本的类别是13…
然后我们可以看到,我们创建的label:
tensor.scatter_(dim,index,src)这三个参数在这里设置为
- dim=1: 表示按照列进行填充
- index=batch_data.label:表示把batch_data.label里面的元素值作为下标,去下标对应位置(这里的"对应位置"解释为列,如果dim=0,那就解释为行)进行填充
- src=1:表示填充的元素值为1,src也可以是一个跟batch_data.label同样大小的tensor,具体可以看这篇文章
最后经过scatter_我们得到label长下面的样子:
所以应该知道: 当dim设置为1的时候,我们遍历label的每一行,然后去填充该行的一些列
- 哪些列?——去batch_data.label的对应行找(所以这两个tensor的行数在dim=1时,需要相等;如果dim=0,则"需要变换的那个tensor"和“index”的列数需要相等)
- 填充什么东西?——去src那个参数的地方找,如果src就是一个值,那填充的就是那个值,否则src就必须要是一个跟index参数大小一样的一个tensor。
这样我们就可以进一步把pred_label和label送入一个BCEloss function去计算loss了。
self.training_criterion = nn.BCELoss()
loss = self.training_criterion(pred_label, label)