Description
题目背景
小Y同学是一个高中信息学奥赛里面的蒟蒻,在各位大佬暴切各种dark火题的时候,他还在瑟瑟发抖的学习各种基础知识,在巩固很久的基础知识之后,他终于尝试学习且粗略学懂了他之前一直觉得很高大上的树链剖分,于是他准备尝试去写一写模板题。他点开了做题网站,看到了讨论版一句“萌新学妹刚学oi一个月,树链剖分打错了,求大佬看看”,于是小Y同学立即正义感爆棚点进了这一道题。
题目描述
你梦见了一棵树,这是一棵很茂密的树,因此它有很多的分支。
你注意到这颗树的有 n n n 个果实,每一棵果实都有自己的编号,且标号为 1 1 1 的果实在最上面,像是一个根节点,树上的一个果实 u u u 到另一个果实 v v v 的距离,都恰好是一个整数 c c c ,因为已经固定好了 1 1 1 号果实为根节点,所以这棵树的形状已经确定了,你想知道摘下一颗果实,会连带着把它的子树的果实也给摘下来。
而这个摘下来所得到的贡献为(数字出现的次数*数字)的平方
比如 2 2 2出现了 5 5 5次,那么贡献即为 ( 2 ∗ 5 ) 2 (2*5)^2 (2∗5)2
数字为两个果实之间的距离即树的边权值,边权值的范围为 c c c 。
所以你有 m m m组询问,想知道当前询问的果实连带着它的子树果实被摘下来时的贡献是多少。
Input
第 1 1 1行,三个整数 n , m , c n,m,c n,m,c分别表示树的大小,询问的个数,边权的范围。 ( 1 ≤ n , m , c ≤ 100000 ) (1≤n,m,c≤100000) (1≤n,m,c≤100000)
第 2 − n 2-n 2−n行,每行三个整数 u , v , v i u,v,vi u,v,vi表示从 u u u到 v v v有一条 v i vi vi边权的边。
接下来 m m m行,每行一个整数表示询问的节点。
Output
输出 m m m行,每行一个整数代表子树的权值大小。(保证不会超过 l o n g l o n g long long longlong)
Sample Input 1
11 6 10
1 2 9
2 3 1
3 4 6
2 5 7
4 6 5
5 7 7
7 8 8
7 9 3
7 10 6
3 11 3
5
7
10
6
1
5
Sample Output 1
158
109
0
0
547
158
#include<iostream>
#include<cstring>
#include<map>
const int N = 1e5+10;
using namespace std;
int h[N],e[N<<1],ne[N<<1],idx;
void add(int a,int b){
ne[idx] = h[a],e[idx] = b,h[a] = idx++;
}
int w[N];
int n,m,c;
map<int,int> mp;
int fa[N],siz[N],son[N],top[N];
int L[N],R[N],tim,index[N];
void dfs1(int u,int f){
fa[u] = f;
siz[u] = 1;
L[u] = ++tim;
index[tim] = u;
for(int i = h[u];~i;i=ne[i]){
int y = e[i];
if(y == f)continue;
dfs1(y,u);
siz[u] += siz[y] ;
if(siz[y] > siz[son[u]])son[u] = y;
}
R[u] = tim;
}
void dfs2(int x,int u){
top[x] = u;
if(!son[x])return ;
dfs2(son[x],u);
for(int i = h[x];~i;i=ne[i]){
int y = e[i];
if(y != fa[x] && y != son[x])dfs2(y,y);
}
}
long long ans[N];
int skp;
void get_data(int u){
// cout << " u : "<< u << endl;
//cout << L[u] << ' ' << R[u] << endl;
for(int i = L[u]+1;i<=R[u];i++){
int t = index[i];
int x = w[t],y = mp[w[t]];
ans[u] -= 1LL*x*x*y*y;
mp[w[t]] ++;
y++;
ans[u] += 1LL*x*x*y*y;
if(index[i] == skp){
i = R[index[i]];
ans[u] += ans[skp];
continue;
}
}
// cout << ans[u] << endl;
}
void dsu(int u){
for(int i = h[u];~i;i=ne[i]){
int y = e[i];
if(y == fa[u] || y == son[u])continue;
dsu(y);
}
if(son[u]){
dsu(son[u]);
skp = son[u];
}
get_data(u);
// int x = w[u] , y = mp[w[u]];
// cout << "u:" << u << " x : " << x << " Y : " << y << endl;
// cout << ans[u] << endl;
// ans[u] -= 1LL*x*x*y*y;
// cout << ans[u] << endl;
// if(y>1) y -- , ans[u] += 1LL*x*x*y*y;
if(u == top[u]){
mp.clear();
skp = 0;
}
}
int main(){
cin >> n >> m >> c;
memset(h,-1,sizeof h);
for(int i = 1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b),add(b,a);
w[b] = c;
}
dfs1(1,0);
dfs2(1,0);
dsu(1);
while(m--){
int a;
scanf("%d",&a);
printf("%lld\n",ans[a]);
}
return 0;
}