torch.nn.Softshrink
Prototype
CLASS torch.nn.Softshrink(lambd=0.5)
Parameters
- lambd (float) – the λ value in the Softshrink formulation. Default: 0.5. Must be greater than or equal to 0.
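To see how `lambd` controls the width of the band that is mapped to zero, the sketch below applies `nn.Softshrink` with a few values of `lambd` to one fixed input. The input values are chosen arbitrarily for illustration. With `lambd=0` nothing is shrunk, so every element passes through unchanged.

```python
import torch
import torch.nn as nn

# Fixed, arbitrary input so the results are reproducible
x = torch.tensor([-2.0, -0.5, 0.0, 0.3, 1.5])

results = {}
for lambd in (0.0, 0.5, 1.0):
    # Elements with |x| <= lambd become 0; the others move lambd closer to 0
    results[lambd] = nn.Softshrink(lambd=lambd)(x)
    print(f"lambd={lambd}: {results[lambd]}")
```

Larger `lambd` zeroes more elements and pulls the surviving ones further toward zero.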
Definition
$$\text{SoftShrinkage}(x)=\begin{cases} x-\lambda, & \text{if } x > \lambda \\ x+\lambda, & \text{if } x < -\lambda \\ 0, & \text{otherwise} \end{cases}$$
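The piecewise definition can be checked with a plain-Python sketch (no PyTorch needed). The names `soft_shrinkage` and `soft_shrinkage_closed_form` are hypothetical helpers for illustration, and raising `ValueError` for a negative `lambd` is our own choice that mirrors the documented constraint. The second function uses the equivalent compact form sign(x) · max(|x| − λ, 0), which follows directly from the three cases: outside the band [−λ, λ] the magnitude drops by λ and the sign is kept, inside it the result is 0.

```python
import math


def soft_shrinkage(x: float, lambd: float = 0.5) -> float:
    """Direct transcription of the three-case definition."""
    if lambd < 0:
        # Hypothetical check mirroring the documented constraint lambd >= 0
        raise ValueError("lambd must be >= 0")
    if x > lambd:
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0


def soft_shrinkage_closed_form(x: float, lambd: float = 0.5) -> float:
    """Equivalent compact form: sign(x) * max(|x| - lambd, 0)."""
    return math.copysign(max(abs(x) - lambd, 0.0), x)


for v in (-2.0, -0.5, 0.0, 0.3, 1.5):
    print(v, soft_shrinkage(v), soft_shrinkage_closed_form(v))
```

Note that the boundary points x = ±λ fall into the "otherwise" branch, so they map to 0.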
Figure
(Plot of Softshrink not preserved.)
Code
import torch
import torch.nn as nn
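m = nn.Softshrink()          # default lambd=0.5
input = torch.randn(4)       # random input, so values differ between runs
output = m(input)            # element-wise soft shrinkage
print("input: ", input)
print("output: ", output)
# Sample output from one run:
# input: tensor([ 0.9876, -2.0183, -0.7573, -1.7960])
# output: tensor([ 0.4876, -1.5183, -0.2573, -1.2960])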
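`nn.Softshrink` forwards to `torch.nn.functional.softshrink`. The sketch below, assuming a reasonably recent PyTorch, checks on a seeded random input that the module, the functional form, and a hand-written `torch.where` transcription of the definition all agree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so runs are repeatable
x = torch.randn(8)
lambd = 0.5

out_module = nn.Softshrink(lambd=lambd)(x)     # module form
out_functional = F.softshrink(x, lambd=lambd)  # functional form

# Direct transcription of the three-case definition
out_manual = torch.where(
    x > lambd,
    x - lambd,
    torch.where(x < -lambd, x + lambd, torch.zeros_like(x)),
)

print(out_module)
print(torch.allclose(out_module, out_functional), torch.allclose(out_module, out_manual))
```

The functional form is convenient when `lambd` varies per call and there is no need to hold a module instance.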