import torch
if __name__ == '__main__':
x = torch.randn(2,3,dtype=torch.double)
print(x)
x = torch.where(x>0,x,0.)
print(x)
y = torch.randn(1,3)
print(y)
print(torch.cat((x,y),dim=0))
torch.where (condition,x,y)三个参数 ,x,y是形状一样或者可以广播的,condition是满足的条件,当某个位置满足条件的时候,取x相应的位置上的值,否则取y相应位置上的值。
cat拼接操作,二维的情况按行拼接,dim = 0, 沿着列拼接,dim = 1.
tensor([[-0.2799, -0.6307, -0.7366],
[ 0.4416, -0.4934, -0.2466]], dtype=torch.float64)
tensor([[0.0000, 0.0000, 0.0000],
[0.4416, 0.0000, 0.0000]], dtype=torch.float64)
tensor([[-0.1325, -0.6699, -0.1472]])
tensor([[ 0.0000, 0.0000, 0.0000],
[ 0.4416, 0.0000, 0.0000],
[-0.1325, -0.6699, -0.1472]], dtype=torch.float64)