看完题一开始想用floyd做,然后看到数据范围瞬间懵逼;
然后想到用LCA,开始现学,然后发现比赛结束了;
官方题解说用树链。。。还是懵逼;
直到看到http://blog.csdn.net/wjw1340/article/details/77484391;的思路
woc怎么那么水;
题意:给你一棵树,每个节点都有一定值;然后m个询问,问从u节点到v节点的最短路上满足条件的值的和;
思路:首先就是建树没得说,然后设1为根节点,然后求出每一节点的父亲节点和自己所在的层数;等询问的时候就能通过两点的层数判断了(当然首先要判断u,v两点的值是否满足条件,满足就加上它),如果u的层数大于v的层数就操作u:1将u变化为u的父亲节点,然后check(u);再判断u,v的层数,直到两者层数相等;如果v的层数大于u的层数操作同理;
当两者层数相同了,就可以判断u和v是否同一点了,如果不同,就一起操作往上走和加上点值;直到他们相等之后;减去该点值,因为在操作中都是事先先加了该点的值;
然后将最终的解存入数组,最后一起输出就行;
注意要long long。。。
#include<bits/stdc++.h>
using namespace std;
#define maxn 500001
#define ps push_back
#define debug1 cout<<"1"<<endl;
#define debug2 cout<<"2"<<endl;
#define debug3 cout<<"3"<<endl;
int dp[maxn],fa[maxn];
int val[maxn];
bool vis[maxn];
long long arr[maxn];
vector<int> tre[maxn];
int n,m;
int read(int &n)
{
char ch=' ';int q=0,w=1;
for(;(ch!='-')&&((ch<'0')||(ch>'9'));ch=getchar());
if(ch=='-')w=-1,ch=getchar();
for(;ch>='0' && ch<='9';ch=getchar())q=q*10+ch-48;
n=q*w; return n;
}
void dfs(int root,int father,int deep)
{
vis[root]=1;
dp[root]=deep;
fa[root]=father;
for(int i=0;i<tre[root].size();i++)
{
if(vis[tre[root][i]]==0)
dfs(tre[root][i],root,deep+1);
}
return;
}
int check(int id,int a,int b)
{
if(val[id]<=b&&val[id]>=a)
return val[id];
return 0;
}
long long lac(int u,int v,int a,int b)
{
long long ans=0;
ans+=check(u,a,b);
ans+=check(v,a,b);
while(dp[u]<dp[v])
{
//debug1
v=fa[v];
ans+=check(v,a,b);
}
while(dp[u]>dp[v])
{
//debug2
u=fa[u];
ans+=check(u,a,b);
}
while(u!=v)
{
//debug3
u=fa[u];
ans+=check(u,a,b);
v=fa[v];
ans+=check(v,a,b);
}
ans-=check(v,a,b);
return ans;
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
int u,v,a,b;
for(int i=1;i<=n;i++)
{
read(val[i]);
vis[i]=0;
tre[i].clear();
}
for(int i=1;i<n;i++)
{
read(u);read(v);
tre[u].ps(v);
tre[v].ps(u);
}
dfs(1,-1,1);
// for(int i=1;i<=n;i++)
// cout<<i<<": "<<dp[i]<<" "<<fs[i]<<endl;
//
// read(u);read(v);read(a);read(b);
// int ans=lac(u,v,a,b);
// printf("%d",ans);
for(int i=1;i<=m;i++)
{
read(u);read(v);read(a);read(b);
arr[i]=lac(u,v,a,b);
// printf(" %d",ans);
}
printf("%lld",arr[1]);
for(int i=2;i<=m;i++)
printf(" %lld",arr[i]);
printf("\n");
}
}