torch.clamp — PyTorch 2.0 documentation
将输入中的所有元素限制在[min, max]范围内。
>>> a = torch.randn(4)
>>> a
tensor([-1.7120, 0.1734, -0.0478, -0.0922])
>>> torch.clamp(a, min=-0.5, max=0.5)
tensor([-0.5000, 0.1734, -0.0478, -0.0922])
>>> min = torch.linspace(-1, 1, steps=4)
>>> torch.clamp(a, min=min)
tensor([-1.0000, 0.1734, 0.3333, 1.0000])