题目
试题编号 | 202309-3 |
---|---|
试题名称 | 梯度求解 |
时间限制 | 1s |
内存限制 | 512MB |
问题描述 | ![]() |
输入输出 | ![]() |
样例1 | ![]() |
样例2 | ![]() |
样例3 | ![]() |
题目分析
题目的要求是求解一个表达式中的偏导数。主要的问题是怎么存储这一表达式。我的想法是存储成一棵树的形式。树的节点要存储当前节点的函数值y和导数值yy。
- 所有的 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn以及所有的常数都是树的叶子节点。y自然就是变量的值或者常数的值。而如果是常量节点,那么yy=0;否则看本次要求解的偏导数是不是该节点的变量。如果是则yy=1,否则也视为常数yy=0。
- 树的根节点以及中间的非叶子节点都是
+
,
−
,
∗
+,-,*
+,−,∗。该这一类的节点的y和yy设置如下:
y = c a l c u l a t e ( y l , y r ) , y y = c a l c u l a t e ( y y l , y y r ) y=calculate(y_{l},y_{r}),yy=calculate(yy_{l},yy_{r}) y=calculate(yl,yr),yy=calculate(yyl,yyr)计算规则很简单,题目也给了,就是简单的加减乘的导数计算。 - 要注意的是中间结果要对1e9+7取模。如果结果是负数那么要: r e s u l t + = 1 e 9 + 7 result+=1e9+7 result+=1e9+7。最后输出根节点的yy即可。
AC代码
有些要注意的小问题:
- 常量值可能是-100等,这个**-100和 − - −**要分别处理。可能引发错误把-100认为是减号。
- 建树的时候,按照栈的结构:遇到’+‘,’-‘,’*'。先出栈的节点是右节点,然后是左节点。
- 变量的数量可能大于9,所以不能认为x后面跟的第一个数字就是变量的标号。
#include <iostream>
#include <string.h>
#include <sstream>
#include <stack>
#define rep(i,a,n) for(int i=a;i<n;i++)
#define ll long long
using namespace std;
const int mod = 1e9+7;
int M(ll h){return h<0?h+mod:h;}
int n,m,num;
string bolan,t;
stringstream st,ss;
int a[101];
struct node{
char com;
int x_num;
ll y;
ll yy;
int order;
int left;
int right;
node():order(0),x_num(0),y(0),yy(0),left(0),right(0){}
node(char cc,int _x_num,int _y,int _yy):com(cc),x_num(_x_num),y(_y),yy(_yy){}
}_u[130];
stack<node>q;
node tran(string s){
node u;
ss.clear();
if(isdigit(s[0])||(s[0]=='-'&&s.size()>1)){
ss<<s;
ss>>u.y;
u.com='c';
}
else if(isalpha(s[0])){
s.erase(s.begin());
ss<<s;
ss>>u.x_num;
u.com='x';
}
else{
u.com=s[0];
}
return u;
}
int index=1;
void bfs(int _i){
if(_u[_i].com=='x'){
_u[_i].y=a[_u[_i].x_num];
_u[_i].yy=(_u[_i].x_num==num?1:0);
}
else if(_u[_i].com=='+'){
int _left=_u[_i].left;
int _right=_u[_i].right;
bfs(_left);bfs(_right);
_u[_i].y=_u[_left].y+_u[_right].y;
_u[_i].yy=_u[_left].yy+_u[_right].yy;
_u[_i].y%=mod;_u[_i].yy%=mod;
_u[_i].y=M(_u[_i].y);_u[_i].yy=M(_u[_i].yy);
}
else if(_u[_i].com=='-'){
int _left=_u[_i].left;
int _right=_u[_i].right;
bfs(_left);bfs(_right);
_u[_i].y=_u[_left].y-_u[_right].y;
_u[_i].yy=_u[_left].yy-_u[_right].yy;
_u[_i].y%=mod;_u[_i].yy%=mod;
_u[_i].y=M(_u[_i].y);_u[_i].yy=M(_u[_i].yy);
}
else if(_u[_i].com=='*'){
int _left=_u[_i].left;
int _right=_u[_i].right;
bfs(_left);bfs(_right);
_u[_i].y=_u[_left].y*_u[_right].y;
_u[_i].yy=(_u[_left].y*_u[_right].yy)%mod+(_u[_left].yy*_u[_right].y)%mod;
_u[_i].y%=mod;_u[_i].yy%=mod;
_u[_i].y=M(_u[_i].y);_u[_i].yy=M(_u[_i].yy);
}
}
int main(){
cin>>n>>m;
getchar();
getline(cin,bolan);
st<<bolan;
while(st>>t){
_u[index]=tran(t);
_u[index].order=index++;
}
rep(i,0,index){
if(_u[i].com=='+'||_u[i].com=='*'||_u[i].com=='-'){
_u[i].right=q.top().order;q.pop();
_u[i].left=q.top().order;q.pop();
}
q.push(_u[i]);
}
while(m--){
cin>>num;
rep(i,1,n+1)cin>>a[i];
bfs(index-1);
cout<<_u[index-1].yy<<endl;
}
return 0;
}
/*
2 2
x1 x1 x1 * x2 + *
1 2 3
2 3 4
*/