>>> x = torch.ones(3, 3)
>>> x[1].fill_(2)
tensor([ 2., 2., 2.])
>>> x[2].fill_(3)
tensor([ 3., 3., 3.])
>>> x
tensor([[ 1., 1., 1.],
[ 2., 2., 2.],
[ 3., 3., 3.]])
>>> torch.renorm(x, 1, 0, 5)
tensor([[ 1.0000, 1.0000, 1.0000],
[ 1.6667, 1.6667, 1.6667],
[ 1.6667, 1.6667, 1.6667]])
第一行的L1范数是3,不大于5,不处理。
第二行的L1范数是6,大于5,每个元素要除以6,再乘以 5。
第三行的L1范数是9,大于5,每个元素要除以9,再乘以 5。
参考:
https://pytorch.org/docs/stable/generated/torch.renorm.html?highlight=renorm#torch.renorm