import torch
from torch import Tensor
def soft(x: Tensor, T):
#这是一个软阈值函数
#x是输入的变量,T是阈值
X_abs = torch.abs(x)
def Complex_sign(x):
return x / X_abs
def Complex_max(x, b):
return torch.where(x > b, x, b)
soft = Complex_sign(x) * (Complex_max(0, torch.abs(x)-T))
return soft
与MATLAB代码进行过对比,结果完全一致。