zeros = torch.zeros(real_size,CLASSES_COUNT) for k in range(real_size ): zeros[k, batch_label[k]] = 1
更新了。
zeros[torch.arrange(batch_label.size(0)),batch_label]=1
以下是老的。两种方式生成的结果一样。估计速度后者更快吧。
代码解读:
zeros是一个tensor,它的行数是real_size,列数是类别数 CLASSES_COUNT。每一行代表一个样本,每个样本在某一列的值是1,其他是0。所以第一行的代码就是为了全是0的一个tensor。
batch_label是一个列表,其中存放了每行中哪个列值是1,相当于存放了一系列的索引值。元素在batch_label中的索引代表在zeros中的行索引,元素值代表在zeros中列索引。
然后遍历样本数,k行的batch_label[k]置为1。
[0, 4, 2, 1, 1, 4, 1, 0, 2, 0] 这是batch_label
tensor([[1., 0., 0., 0., 0.],这是最终生成的zeros tensor
[0., 0., 0., 0., 1.],
[0., 0., 1., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1.],
[0., 1., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.]])