a = torch.randn(2,3,4,4)
a[1][2][3][1] = np.nan
a[1][2][3][2] = np.nan
result = torch.nonzero(torch.isnan(a)==True)
print(result)
# tensor([[1, 2, 3, 1],
# [1, 2, 3, 2]])
那么就可以查看result.shape[0] 是否 > 0,如果 > 0,就代表a这个tensor里肯定有nan了
a = torch.randn(2,3,4,4)
a[1][2][3][1] = np.nan
a[1][2][3][2] = np.nan
result = torch.nonzero(torch.isnan(a)==True)
print(result)
# tensor([[1, 2, 3, 1],
# [1, 2, 3, 2]])
那么就可以查看result.shape[0] 是否 > 0,如果 > 0,就代表a这个tensor里肯定有nan了