Code:
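import numpy as np
import scipy.sparse as sp
import torch

x = torch.randn(3, 4)
print(x)

# Mean of each row, kept as a (3, 1) column so it broadcasts against x
row_mean = torch.mean(x, dim=1, keepdim=True)
print(row_mean)

# Boolean mask: True where an element is greater than its row mean
mz_items = x > row_mean
print(mz_items)

# (row, col) index of every True entry, one pair per line
mz_items_np = mz_items.numpy()
index = np.argwhere(mz_items_np)
print(index)

# Values at those positions, in the same row-major order as index
vals = x[mz_items].numpy()
# shape must be at least (max row + 1, max col + 1); x.shape would be the tight fit
adj = sp.coo_matrix((vals, (index[:, 0], index[:, 1])), shape=(7, 7))
print(adj)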
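On a 2-D boolean array, np.nonzero already returns a (rows, cols) tuple, so the argwhere call and the column slicing can be folded into one step. A minimal sketch with SciPy, assuming the matrix should be sized to x itself rather than (7, 7); the helper name mask_to_coo is hypothetical, and the printed example values are hard-coded so the result is reproducible:

```python
import numpy as np
import scipy.sparse as sp
import torch


def mask_to_coo(x: torch.Tensor, mask: torch.Tensor) -> sp.coo_matrix:
    """Hypothetical helper: keep the entries of x where mask is True."""
    x_np = x.numpy()
    mask_np = mask.numpy()
    # np.nonzero gives (row_indices, col_indices) in row-major order
    rows, cols = np.nonzero(mask_np)
    # Boolean indexing walks the array in the same row-major order,
    # so vals lines up with (rows, cols)
    vals = x_np[mask_np]
    return sp.coo_matrix((vals, (rows, cols)), shape=x_np.shape)


x = torch.tensor([[0.4772, -0.7524, 0.1866, 0.2946],
                  [0.3332, 0.6116, 1.0253, -0.8315],
                  [-0.5561, 0.5256, 0.5922, 1.3521]])
mask = x > x.mean(dim=1, keepdim=True)
adj = mask_to_coo(x, mask)
print(adj.toarray())
```

Sizing the matrix with x's own shape avoids the empty rows and columns that an oversized shape adds.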
Output:
tensor([[ 0.4772, -0.7524, 0.1866, 0.2946],
[ 0.3332, 0.6116, 1.0253, -0.8315],
[-0.5561, 0.5256, 0.5922, 1.3521]])
tensor([[0.0515],
[0.2846],
[0.4785]])
tensor([[ True, False, True, True],
[ True, True, True, False],
[False, True, True, True]])
[[0 0]
[0 2]
[0 3]
[1 0]
[1 1]
[1 2]
[2 1]
[2 2]
[2 3]]
(0, 0) 0.47724348
(0, 2) 0.18663067
(0, 3) 0.29461738
(1, 0) 0.33317146
(1, 1) 0.6115727
(1, 2) 1.0253372
(2, 1) 0.5256042
(2, 2) 0.5922486
(2, 3) 1.3520662
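The same nine entries can also be kept as a PyTorch sparse tensor, skipping the NumPy/SciPy round trip. A minimal sketch using torch.sparse_coo_tensor, assuming a sparse torch tensor is acceptable downstream; the helper name above_row_mean_sparse is hypothetical, and the printed example values are hard-coded for reproducibility:

```python
import torch


def above_row_mean_sparse(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sparse COO tensor of entries above their row mean."""
    # Boolean mask: True where an element exceeds its row mean
    mask = x > x.mean(dim=1, keepdim=True)
    # nonzero() returns (nnz, 2) row/col pairs; sparse_coo_tensor wants (2, nnz)
    indices = mask.nonzero().T
    # Boolean indexing is row-major, matching the order of nonzero()
    values = x[mask]
    # coalesce() so indices()/values() can be read back
    return torch.sparse_coo_tensor(indices, values, size=x.shape).coalesce()


x = torch.tensor([[0.4772, -0.7524, 0.1866, 0.2946],
                  [0.3332, 0.6116, 1.0253, -0.8315],
                  [-0.5561, 0.5256, 0.5922, 1.3521]])
s = above_row_mean_sparse(x)
print(s.indices().T)
print(s.values())
```

The (row, col) pairs printed here should match the nine pairs listed above.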
Addendum: to mark the row-wise maximum instead of the values above the mean:
mz_items = x == torch.max(x, dim=1, keepdim=True).values
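Applied to the example tensor printed above, the max mask should keep exactly one entry per row when there are no ties. With dim set, torch.max returns a (values, indices) named tuple, so the column of each maximum is also available directly. A small sketch with the printed values hard-coded:

```python
import torch

x = torch.tensor([[0.4772, -0.7524, 0.1866, 0.2946],
                  [0.3332, 0.6116, 1.0253, -0.8315],
                  [-0.5561, 0.5256, 0.5922, 1.3521]])

# With dim given, torch.max returns a (values, indices) named tuple
row_max = torch.max(x, dim=1, keepdim=True)
# True wherever an element equals its row maximum
max_mask = x == row_max.values
print(max_mask)
print(max_mask.nonzero())           # (row, col) pairs of the maxima
print(row_max.indices.squeeze(1))   # column of each row maximum
```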
torch.argmax returns the position (index) of the maximum; with dim=1 it gives one column index per row.
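The two approaches differ when a row contains tied maxima: the equality mask flags every tied position, while torch.argmax returns a single index per row (the PyTorch documentation states it is the first maximal one). A small sketch on a made-up tensor with a deliberate tie:

```python
import torch

# Made-up example: row 0 has a tie (3.0 appears twice)
t = torch.tensor([[1.0, 3.0, 3.0, 2.0],
                  [5.0, 0.0, 4.0, 4.0]])

# Equality mask marks every position equal to the row maximum
tie_mask = t == t.max(dim=1, keepdim=True).values
# argmax yields one column index per row
pos = torch.argmax(t, dim=1)
print(tie_mask)
print(pos)
```

Use the mask when all tied maxima matter, and argmax when exactly one position per row is needed.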