1、torch.where()
torch.where(condition, x, y)---->tensor
- 功能:就是个三目运算符,第一个参数为条件,这个条件为真则为x,否则为y
例子:
a = torch.ones(2,2)
b = torch.zeros_like(a)
cond = torch.randn(2,2) # 这个是构成条件的矩阵
cond = tensor([[-1.0151, 0.9591],
[-0.8334, -1.2766]])
torch.where(cond>0.5, a, b) # 当cond里面的元素>0.5时返回a对应位置的元素,否则为b对应位置的元素
# 输出:
tensor([[0., 1.],
[0., 0.]])
2、torch.gather()
torch.gather(input, dim, index)
- 功能:这个作用就是从输入input中,选择dim维度的数据,再利用索引index来找到这些数据。
例子:
input = torch.randn(3,3) # 生成输入
input =
tensor([[-0.1183, -0.7858, -0.6806],
[-1.0750, -0.8349, -0.4402],
[-0.5234, -0.8766, 0.4833]])
index = torch.arange(3).view(-1, 3) # 生成索引
index = tensor([[0, 1, 2]])
# 在0维(也就是行)索引查找input的值
torch.gather(input, dim=0, index=index)
# 输出:
tensor([[-0.1183, -0.8349, 0.4833]])
# 在1维(也就是列)索引查找input的值
torch.gather(input, dim=0, index=index)
# 输出:
tensor([[-0.1183, -0.7858, -0.6806]])
若有用,欢迎点赞,若有错,请指正,谢谢!