一、where函数
1、简介
return a tensor of elements selected from either x or y, depending on condition 返回一个tensor,根据生成规则选择是从x取出还是y取出。 虽然where函数操作可以用for循环代替,但是where是通过GPU加速的,在深度学习中使用where速度会更快。
2.函数
torch.where(cond,x,y)
cond
:生成规则自定义x,y
:两个tensor
3.代码实例
cond = torch. rand( 2 , 2 )
print ( cond)
a = torch. full( [ 2 , 2 ] , 0 )
b = torch. full( [ 2 , 2 ] , 1 )
ans = torch. where( cond> 0.5 , a, b)
print ( ans)
二、gather函数
1、简介
官方文档:
torch.gather(input,dim,index,out=None) -> Tensor
Gather values along an axis specified by dim For a 3D tensor the output is specified by:
out[i][j][k] = input[ index[i][j][k] ][j][k], dim = 0 out[i][j][k] = input[i][ index[i][j][k] ][k], dim = 1 out[i][j][k] = input[i][j][ index[i][j][k] ] , dim = 2 从原tensor中获取指定dim和指定index的数据
2、函数
torch.gather(input,dim,index,out=None)
input
:保持与index.shape
一致,从input中获取数值dim
:从input
中的哪一维度获取index
:保持与input.shape
一致,可以理解为tensor下标,根据index生成最终的tensor
3、代码实例
prob = torch. randn( 4 , 10 )
idx = prob. topk( dim= 1 , k= 3 )
idx = idx[ 1 ]
print ( idx)
label = torch. arange( 10 ) + 100
print ( label)
ans = torch. gather( label. expand( 4 , 10 ) , dim= 1 , index= idx. long ( ) )
print ( ans)