pytorch下可训练分段函数的写法
这篇博客主要讲如何写一个可训练求导的分段函数,并通过代码验证其可行性
假设我们要实现这样一个分段函数:
F ( a ) = { + 1 a > 2.5 0.25 ∗ a + 0.375 0.5 < a ≤ 2.5 a − 0.5 < a ≤ 0.5 0.25 ∗ a − 0.375 − 2.5 < a ≤ − 0.5 − 1 a ≤ − 2.5 F(a)=\left\{ \begin{array}{rcl} +1 & & {a > 2.5}\\ 0.25*a + 0.375 & & {0.5 < a \leq 2.5}\\ a & & {-0.5 < a \leq 0.5}\\ 0.25*a - 0.375 & & {-2.5 < a \leq -0.5}\\ -1 & & {a \leq -2.5} \end{array} \right. F(a)=⎩⎪⎪⎪⎪⎨⎪⎪⎪⎪⎧+10.25∗a+0.375a0.25∗a−0.375−1a>2.50.5<a≤2.5−0.5<a≤0.5−2.5<a≤−0.5a≤−2.5
首先我们生成一个分布涵盖这5个区间的张量:
import torch
a = torch.tensor([-3, -1.5, 0.1, 2.3, 5])
a.requires_grad = True
第二步,分别求出他们的区间:
b1 = a <= -2.5
b2 = (a <= -0.5) & (a > -2.5)
b3 = (a <= 0.5) & (a > - 0.5)
b4 = (a <= 2.5) & (a > 0.5)
b5 = a > 2.5
运行结果如图所示
符合区间的条件的位置为TRUE,不符合的为FALSE
第三步,五种分段函数的计算:
a1 = -1
a2 = 0.25 * a - 0.375
a3 = a
a4 = 0.25 * a + 0.375
a5 = 1
a1 = a1 * b1 #和对应区间相乘
a2 = a2 * b2
a3 = a3 * b3
a4 = a4 * b4
a5 = a5 * b5
运行结果如图所示
这样,函数计算的值只在自己的区间内有效,其余地方为0;最后我们再将这五个加起来就完成了分段函数的计算:
c = a1 + a2 + a3 + a4 + a5
运行结果如图所示
接下来我们通过反向求导验算一下:
c.backward(torch.tensor([1.,1.,1.,1.,1.]))
backward()中参数的含义,可以参考这篇博客:Pytorch中loss.backward(),loss为矢量时,对参数的理解
ok,我们输出a的梯度试试
验算正确,这样就实现了pytorch下分段函数的可训练