团体天梯 L3-023 计算图 (30 分)

L3-023 计算图 (30 分)

“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 f(x​1​​,x​2​​)=lnx​1​​+x​1​​x​2​​−sinx​2​​ 的计算图。

figure.png

现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入x​1​​=2,x​2​​=5,上述计算图获得函数值 f(2,5)=ln(2)+2×5−sin(5)=11.652;并且根据微分链式法则,上图得到的梯度 ∇f=[∂f/∂x​1​​,∂f/∂x​2​​]=[1/x​1​​+x​2​​,x​1​​−cosx​2​​]=[5.500,1.716]。

知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(e​x​​,即编程语言中的 exp(x) 函数)、对数(lnx,即编程语言中的 log(x) 函数)和正弦函数(sinx,即编程语言中的 sin(x) 函数)。

友情提醒:

  • 常数的导数是 0;x 的导数是 1;e​x​​ 的导数还是 e​x​​;lnx 的导数是 1/x;sinx 的导数是 cosx。
  • 回顾一下什么是偏导数:在数学中,一个多变量的函数的偏导数,就是它关于其中一个变量的导数而保持其他变量恒定。在上面的例子中,当我们对 x​1​​ 求偏导数 ∂f/∂x​1​​ 时,就将 x​2​​ 当成常数,所以得到 lnx​1​​ 的导数是 1/x​1​​,x​1​​x​2​​ 的导数是 x​2​​,sinx​2​​ 的导数是 0。
  • 回顾一下链式法则:复合函数的导数是构成复合这有限个函数在相应点的导数的乘积,即若有 u=f(y),y=g(x),则 du/dx=du/dy⋅dy/dx。例如对 sin(lnx) 求导,就得到 cos(lnx)⋅(1/x)。

如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。

输入格式:

输入在第一行给出正整数 N(≤5×10​4​​),为计算图中的顶点数。

以下 N 行,第 i 行给出第 i 个顶点的信息,其中 i=0,1,⋯,N−1。第一个值是顶点的类型编号,分别为:

  • 0 代表输入变量
  • 1 代表加法,对应 x​1​​+x​2​​
  • 2 代表减法,对应 x​1​​−x​2​​
  • 3 代表乘法,对应 x​1​​×x​2​​
  • 4 代表指数,对应 e​x​​
  • 5 代表对数,对应 lnx
  • 6 代表正弦函数,对应 sinx

对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。

题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。

输出格式:

首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。

输入样例:

7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4

输出样例:

11.652
5.500 1.716

模拟求导过程:

思路等代码完事再整理

通过记忆性递归进行优化

#include<iostream>
#include<map>
#include<math.h>
using namespace std;
struct Node {
	int id;				//结点类型
	double value;		//记录输入结点的值
	int preA, preB;		//记录运算符的两(一)个参数结点
	int post;  //记录是否有后继结点,用于寻找出口
}node[50001];
map<int, map<int, map<int, double>>> save;
double calc(const int& nd, const int& key, const int& p) { //第一个参数为结点,第二个参数决定是否求导,第三个参数是对谁求导
	if (save[nd][key][p])
		return save[nd][key][p];
	else {
		switch (node[nd].id) {
		case 0:		//value
			return save[nd][key][p] = (key == 0 ? node[nd].value : (nd == p ? 1 : 0)); break;
		case 1:		//plus
			return save[nd][key][p] = calc(node[nd].preA, key, p) + calc(node[nd].preB, key, p); break;
		case 2:		//minus
			return save[nd][key][p]=calc(node[nd].preA, key, p) - calc(node[nd].preB, key, p); break;
		case 3:		//multiplies
			return save[nd][key][p] = (key ? calc(node[nd].preA, key, p) * calc(node[nd].preB, 0, p) + calc(node[nd].preA, 0, p) * calc(node[nd].preB, key, p) : calc(node[nd].preA, key, p) * calc(node[nd].preB, key, p)); break;
		case 4:		//divides
			return save[nd][key][p]=(key ? exp(calc(node[nd].preA, 0, p)) * calc(node[nd].preA, key, p) : exp(calc(node[nd].preA, key, p))); break;
		case 5:		//ln
			return save[nd][key][p] = (key ? 1 / (calc(node[nd].preA, 0, p)) * (calc(node[nd].preA, key, p)) : log(calc(node[nd].preA, key, p))); break;
		case 6:		//sin
			return save[nd][key][p] = (key ? cos(calc(node[nd].preA, 0, p)) * calc(node[nd].preA, key, p) : sin(calc(node[nd].preA, key, p)));
		}
	}
}
int main() {
	int n, a, b, c, end = 0, flag = 0;
	double db;;
	cin >> n;
	for (int i = 0; i < n; i++) {
		scanf("%d", &node[i].id);
		if (node[i].id == 0)
			scanf("%lf", &node[i].value);
		else if (node[i].id <= 3) {
			scanf("%d%d", &node[i].preA, &node[i].preB);
			node[node[i].preA].post = 1;
			node[node[i].preB].post = 1;
		}
		else {
			scanf("%d", &node[i].preA);
			node[node[i].preA].post = 1;
		}
	}
	while (node[end].post) end++;
	printf("%0.3lf\n", calc(end, 0, -1));
	for (int i = 0; i < n; i++) {
		if (node[i].id == 0) {
			printf("%s%0.3lf", flag ? " " : "", calc(end, 1, i));
			flag = 1;
		}
	}
	return 0;
}

 

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值