原题
计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 f ( x 1 , x 2 ) = l n x 1 + x 1 x 2 − s i n x 2 f(x_1,x_2)=lnx_1+x_1x_2-sinx_2 f(x1,x2)=lnx1+x1x2−sinx2的计算图。
现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入 x 1 = 1 , x 2 = 5 x_1=1,x_2=5 x1=1,x2=5,上述计算图获得函数值 f ( 2 , 5 ) = l n ( 2 ) + 2 × 5 − s i n ( 5 ) = 11.652 f(2,5)=ln(2)+2×5−sin(5)=11.652 f(2,5)=ln(2)+2×5−sin(5)=11.652;并且根据微分链式法则,上图得到的梯度 ▽ f = [ ∂ f ∂ x 1 , ∂ f ∂ x 2 ] = [ 1 x 1 + x 2 , x 1 − c o s x 2 ] = [ , 5.5001.716 ] \triangledown f=[\frac{\partial f}{\partial x_1},\frac{\partial f}{\partial x_2}]=[\frac1x_1+x_2,x_1-cosx_2]=[,5.5001.716] ▽f=[∂x1∂f,∂x2∂f]=[x11+x2,x1−cosx2]=[,5.5001.716]
知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(即编程语言中的 exp(x) 函数)、对数(lnx,即编程语言中的 log(x) 函数)和正弦函数(sinx,即编程语言中的 sin(x) 函数)。
如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。
输入格式:
输入在第一行给出正整数 N ≤ 5 × 1 0 4 N\le5\times10^4 N≤5×104,为计算图中的顶点数。
以下 N 行,第 i 行给出第 i 个顶点的信息,其中 i=0,1,⋯,N−1。第一个值是顶点的类型编号,分别为:
- 0 代表输入变量
- 1 代表加法,对应 x 1 + x 2 x_1+x_2 x1+x2
- 2 代表减法,对应 x 1 − x 2 x_1-x_2 x1−x2
- 3 代表乘法,对应 x 1 × x 2 x_1\times x_2 x1×x2
- 4 代表指数,对应 e x e^x ex
- 5 代表对数,对应 l n x lnx lnx
- 6 代表正弦,对应 s i n x sinx sinx
对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。
题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。
输出格式:
首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。
输入样例:
7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4
输出样例:
11.652
5.500 1.716
梯度计算
首先考虑计算图上的梯度计算。
注意到顶点储存的是变量和算子,自然而然地,可以考虑用边代表运算过程;具体来说,对于一条从
u
u
u射出,射入
v
v
v的边
⟨
u
,
v
⟩
\langle u,v\rangle
⟨u,v⟩,它记录的是
v
(
u
)
v(u)
v(u)(即将
u
u
u作为变量进行
v
v
v运算)求偏导的结果,也就是
∂
v
∂
u
\frac{\partial v}{\partial u}
∂u∂v
同时,此题中包含的算子是单目或双目的,因此射入一个顶点的边不超过两条。另外,由于
f
:
R
n
↦
R
f:\mathbb R^n\mapsto\mathbb R
f:Rn↦R,因此计算图中一定存在且仅存在一个出度为0的定点(不妨称为输出节点),而入度为0的顶点一定表示变量(不妨称为输入节点)。总上,可以用类似树或k分图的方法处理计算图,换言之,将计算图看作这样一棵特殊的树:以输出节点为根,同一个独立集的深度相同(这是比较粗糙的说法),输入节点是树的叶子节点(同时也是树的全部叶子节点)
一个朴素的想法如下:从输入节点出发,沿着边的方向DFS,每访问一个顶点,便计算一次偏导数值,根据链式法则,在到达输出节点前,计算的偏导数值需要累乘,而同一个变量以此法计算的偏导数值需要累加。
以题目的示意输入
f
(
x
1
,
x
2
)
=
l
n
x
1
+
x
1
x
2
−
s
i
n
x
2
f(x_1,x_2)=lnx_1+x_1x_2-sinx_2
f(x1,x2)=lnx1+x1x2−sinx2为例,可得
∂
f
∂
x
1
=
1
x
1
∗
1
∗
1
+
x
2
∗
1
∗
1
=
1
x
1
+
x
2
∂
f
∂
x
2
=
x
1
∗
1
∗
1
+
c
o
s
x
2
∗
(
−
1
)
=
x
1
−
c
o
s
x
2
\frac{\partial f}{\partial x_1}=\frac1{x_1}*1*1+x_2*1*1=\frac1{x_1}+x_2\\\frac{\partial f}{\partial x_2}=x_1*1*1+cosx_2*(-1)=x_1-cosx_2
∂x1∂f=x11∗1∗1+x2∗1∗1=x11+x2∂x2∂f=x1∗1∗1+cosx2∗(−1)=x1−cosx2
这种想法是可行的,然而我们可以想起从下至上地遍历树,往往是由于题目中没有给出树的全部信息,我们需要建立一棵树(例如最优编码问题),对于已给出树的结构的题目,从上至下遍历往往是更优的方法。
直觉上我们会发现重复访问难以避免:考虑一个储存双目算子的非根节点,假设射入它的边
e
1
e_1
e1被访问
n
1
n_1
n1次,
e
2
e_2
e2被访问
n
2
n_2
n2次,那么射出它的边
e
3
e_3
e3将被访问
n
1
+
n
2
n_1+n_2
n1+n2次,哪怕在访问双目算子前没有出现重复,在访问双目算子时也会发生重复访问。如果所有算子都是双目算子,那么将产生指数级的复杂度。
更好的方法:由上而下地进行DFS,如果递归和计算偏导数的时间复杂度都是
O
(
1
)
O(1)
O(1)的,那么将不会产生重复访问导致的额外时间复杂度。将每次BFS最后访问的叶子节点(也就是输入变量)作为偏导数计算结果的标签,当整个计算图访问结束后将相同标签的偏导数值相加,便得到结果。
事实上,由于
f
:
R
n
↦
R
f:\mathbb R^n\mapsto\mathbb R
f:Rn↦R,使用向后传播方法(Backpropagation)计算梯度能使时间复杂度降低到算子目数之和。
另外地,若
f
:
R
↦
R
n
f:\mathbb R\mapsto\mathbb R^n
f:R↦Rn,由于
f
−
1
:
R
n
↦
R
f^{-1}:\mathbb R^n\mapsto\mathbb R
f−1:Rn↦R,因此可以对应地使用向前传播计算。
代码
首先考虑储存计算图的数据结构,由于算子都是一目或二目的,因此考虑二叉树的储存方法,用四个vector分别储存类型,键值,左子节点,右子节点。主程序如下:
vector<int> type(50000,0),L(50000,-1),R(50000,-1);
vector<double> key(50000,0);
vector<double> val(50000,0),grad(50000,0); //每个节点的函数值和梯度
vector<int> isv(50000,0),isg(50000,0); //标记是否计算了某节点的函数值和梯度
vector<int> var; //变量位置
int main(void){
input(); //读取输入数据,并返回根节点;
getval(root); //计算函数值
getgrad(root); //计算梯度
output(); //按格式输出;
return 0;
}
其中,输入函数为:
int root; //根节点
void input(){
int n;
cin>>n;
vector<int> isroot(n,1); //记录是否为根节点
for(int i=0;i<n;i++){
int r,l,t;
double k=0;
cin>>t;
if(!t){
cin>>k; //读入变量数据
var.push_back(i);
}
else{ //如果是算子
cin>>l;
isroot[l]=0;
L[i]=l;
if(t<4){ //如果是双目算子
cin>>r;
isroot[r]=0;
R[i]=r;
}
}
type[i]=t;
key[i]=k;
}
root=find(isroot.begin(),isroot.end(),1)-isroot.begin(); //记录根节点
}
之后的getval()
函数需要计算每个节点的函数值
double getval(int index){
/*
1: + 4: exp()
2: - 5: log()
3: * 6: sin()
*/
switch(type[index]){
case 0:
val[index]=key[index];
break;
case 1:
val[index]=getval(L[index])+getval(R[index]);
break;
case 2:
val[index]=getval(L[index])-getval(R[index]);
break;
case 3:
val[index]=getval(L[index])*getval(R[index]);
break;
case 4:
val[index]=exp(getval(L[index]));
break;
case 5:
val[index]=log(getval(L[index]));
break;
case 6:
val[index]=sin(getval(L[index]));
break;
}
isv[index]=1;
return val[index];
}
getgrad()
函数用来计算梯度值,需要注意的是用tempgrad
储存临时梯度值。
double tempgrad=1;
void getgrad(int index){
/*
1: + 4: exp()
2: - 5: log()
3: * 6: sin()
*/
switch(type[index]){
case 0:
grad[index]+=tempgrad;
break;
case 1:
tempgrad*=1;
getgrad(L[index]);
tempgrad/=1;
tempgrad*=1;
getgrad(R[index]);
tempgrad/=1;
break;
case 2:
tempgrad*=1;
getgrad(L[index]);
tempgrad/=1;
tempgrad*=-1;
getgrad(R[index]);
tempgrad/=-1;
break;
case 3:
tempgrad*=val[R[index]];
getgrad(L[index]);
tempgrad/=val[R[index]];
tempgrad*=val[L[index]];
getgrad(R[index]);
tempgrad/=val[L[index]];
break;
case 4:
tempgrad*=exp(val[L[index]]);
getgrad(L[index]);
tempgrad/=exp(val[L[index]]);
break;
case 5:
tempgrad*=1/val[L[index]];
getgrad(L[index]);
tempgrad/=1/val[L[index]];
break;
case 6:
tempgrad*=cos(val[L[index]]);
getgrad(L[index]);
tempgrad/=cos(val[L[index]]);
break;
}
}
output()
函数没什么好说的,不过最后一个测试点是
N
=
0
N=0
N=0的输入,不太清楚要求输出什么格式。
void output(){
if(var.size())printf("%.3lf",val[root]);
for(int i=0;i<var.size();i++)
printf("%s%.3lf",i==0?"\n":" ",grad[var[i]]);
}
运行结果: