题意: 给定n(n<=300000)个点的一棵树,每条边有一定的颜色(s<=100000),每个点有点权,如果u->v的路径上相邻的边的颜色都不同那么是一条合法
的路径,求所有合法路径的权值和是多少。
题解:这个树形统计题个人认为想法有点难,比赛的时候想了很多都不行唉~
首先对于每一个节点维护两个值:
能从下面节点延伸到当前节点的合法路径的条数;
这些合法路径的和;
然后根据子节点的这些值得到:当前节点为根节点的所有合法路径的权值和 = 之前深搜的所有子节点向上返回的边数之和 * 当前子节点返回的分数
+之前深搜的所有子节点向上返回的分数之和 * 当前子节点返回的边数+之前深搜的所有子节点向上返回的边数之和 * 当前子节点返回的边数 * 当前点的权。
但是这样不能对已经遍历的子树一一遍历,所以需要对子树先按照颜色排序然后就可以“合并”子树了。
Sure原创,转载请注明出处。
#include <iostream>
#include <cstdio>
#include <memory.h>
#include <algorithm>
using namespace std;
const int maxn = 300002;
struct info
{
int u,v,c;
bool operator < (const info &other) const
{
if(u != other.u)
{
return u < other.u;
}
return c > other.c;
}
}E[maxn << 1];
struct node
{
int v,c;
int next;
}edge[maxn << 1];
int head[maxn],facol[maxn];
__int64 jewel[maxn],ans[maxn],cnt[maxn];
bool isleaf[maxn];
int n,idx;
void addedge(int u,int v,int c)
{
edge[idx].v = v;
edge[idx].c = c;
edge[idx].next = head[u];
head[u] = idx++;
return;
}
void swap(int &a,int &b)
{
int tmp = a;
a = b;
b = tmp;
return;
}
void read()
{
for(int i=1;i<=n;i++)
{
scanf("%I64d",&jewel[i]);
}
for(int i=1;i<n;i++)
{
scanf("%d %d %d",&E[i].u,&E[i].v,&E[i].c);
E[i+n-1] = E[i];
swap(E[i+n-1].u , E[i+n-1].v);
}
return;
}
void make()
{
memset(head,-1,sizeof(head));
idx = 0;
sort(E+1,E+2*n-2);
for(int i=1;i<=2*n-2;i++)
{
addedge(E[i].u,E[i].v,E[i].c);
}
facol[1] = -1;
return;
}
__int64 dfs(int st,int pre)
{
__int64 res = 0,curs = 0,curc = 0,tmps = 0,tmpc = 0;
int bj = -1;
ans[st] = jewel[st];
cnt[st] = 1;
isleaf[st] = true;
for(int i=head[st];i != -1;i=edge[i].next)
{
if(edge[i].v == pre) continue;
isleaf[st] = false;
facol[edge[i].v] = edge[i].c;
res += dfs(edge[i].v , st);
if(facol[edge[i].v] != facol[st])
{
ans[st] += ans[edge[i].v] + jewel[st] * cnt[edge[i].v];
cnt[st] += cnt[edge[i].v];
}
if(facol[edge[i].v] != facol[st] || isleaf[edge[i].v])
{
res += ans[edge[i].v] + jewel[st] * cnt[edge[i].v];
}
if(facol[edge[i].v] != bj)
{
curs += tmps;
curc += tmpc;
tmps = ans[edge[i].v];
tmpc = cnt[edge[i].v];
bj = facol[edge[i].v];
}
else
{
tmps += ans[edge[i].v];
tmpc += cnt[edge[i].v];
}
res += curs * cnt[edge[i].v] + curc * ans[edge[i].v] + curc * cnt[edge[i].v] * jewel[st];
}
return res;
}
int main()
{
while(~scanf("%d",&n))
{
read();
make();
printf("%I64d\n",dfs(1,0));
}
return 0;
}