一、定义
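符号函数 $\operatorname{sign}(x)$ 的定义如下:
$$\operatorname{sign}(x) = \begin{cases} 1, & x > 0 \\ 0, & x = 0 \\ -1, & x < 0 \end{cases}$$
`torch.sign` 对输入张量逐元素应用该函数,返回与输入形状相同的张量。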
二、代码
import torch
# Printing preferences: a wide line limit and no scientific notation.
torch.set_printoptions(sci_mode=False, linewidth=1024)

# 1-D input (vector) mixing negative, zero, and positive entries.
a = torch.tensor([-99.0, -1.0, -0.5, 0.0, 0.0, 0.5, 1.5, 2.5])

# 2-D input with four rows of three values each.
b = torch.tensor(
    [
        [-99.0, -1.0, -0.5],
        [0.0, 0.0, 0.0],
        [1.0, 3.0, 10.5],
        [-0.1, 0.0, 2.0],
    ]
)

# Element-wise sign: -1 for negatives, 0 for zeros, 1 for positives.
after_a = a.sign()
after_b = b.sign()

# Show each input followed by its sign tensor, one per line.
print(a, after_a, b, after_b, sep="\n")
输出
tensor([-99.0000, -1.0000, -0.5000, 0.0000, 0.0000, 0.5000, 1.5000, 2.5000])
tensor([-1., -1., -1., 0., 0., 1., 1., 1.])
tensor([[-99.0000, -1.0000, -0.5000],
[ 0.0000, 0.0000, 0.0000],
[ 1.0000, 3.0000, 10.5000],
[ -0.1000, 0.0000, 2.0000]])
tensor([[-1., -1., -1.],
[ 0., 0., 0.],
[ 1., 1., 1.],
[-1., 0., 1.]])