需求与方法
场景是我现在有一个 4 * 2 的数据(Tensor),每一行是一个样本,第一列代表分类置信度,第二列是真实标签,我想根据真实标签的值将这个数据分成两个 Tensor,一个只包括 Label=1的样本,一个只包括 Label=0 的样本。
假设 Tensor 内容如下
Tensor a = [[92,1], [23,0],[67,1],[33,0]]
我们预期的结果是两个 Tensor 分别是 [[92, 1], [67, 1]] 与 [[23, 0], [33, 0]]。
代码如下:
a = torch.Tensor([[92,1], [23,0],[67,1],[33,0]])
tensor_label_1 = a[a[:,1]== 1, :]
tensor_label_0 = a[a[:,1]== 0, :]
print(tensor_label_1)
print(tensor_label_0)
# 执行结果如下
# tensor([[92., 1.],
# [67., 1.]])
# tensor([[23., 0.],
# [33., 0.]])
理解
其实这里是用到了 PyTorch 的一些功能,为了理解逻辑,可以分开执行代码的部分。
单独执行 `a[:,1]== 1` 可以得到一个 Bool 类型的 Tensor:
tensor([ True, False, True, False])
而使用 Bool 类型的 Tensor 可以作为 Tensor 的索引使用,所以 `a[a[:,1]== 1, :]` 即表示,对于 Tensor a,按照 True, False, True, False 的顺序选取 行(row),而列的索引是 ";" ,表示选中行的每一列都要放入结果中返回。