题意:
给你一棵树,每个节点上有一个权值且每条边都有一个颜色,两个点的组合是有效的当且仅当这两个点路上所有相邻的边的颜色不同,他们的值就是这条路上所有点的权值的和。
题解:
很明显是树形DP,但是听说树分治好像能做。
他这个颜色的影响我们可以用排序来解决,也就是在改变颜色的时候可以将之前种类的颜色的节点的值加进去。
那么首先我们需要一个DP数组来记录答案。但是不能光光只有一个DP数组,因为有些节点的值是不能传回去的,那么我们需要一个sum数组记录下面能够传回去的路径的值的总和。那么我们对于每一个点它儿子节点之间的计算的时候我们也需要知道这个儿子子树的路径值用了几次,那么我们还需要一个siz数组来记录这个点的子树上有多少能用的节点。
那么sum数组的转移方程:
s
u
m
[
x
]
+
=
s
u
m
[
v
e
c
[
x
]
[
p
r
e
]
.
i
d
]
+
s
i
z
[
v
e
c
[
x
]
[
p
r
e
]
.
i
d
]
∗
w
[
x
]
;
sum[x]+=sum[vec[x][pre].id]+siz[vec[x][pre].id]*w[x];
sum[x]+=sum[vec[x][pre].id]+siz[vec[x][pre].id]∗w[x];
也就是儿子节点可行的路径种类数加上这个节点对于之后的运算会用到多少次乘上它的值
dp数组的状态转移方程:
d
p
[
x
]
+
=
d
p
[
n
e
.
i
d
]
+
s
u
m
[
x
]
∗
s
i
z
[
n
e
.
i
d
]
+
s
u
m
[
n
e
.
i
d
]
∗
s
i
z
[
x
]
;
dp[x]+=dp[ne.id]+sum[x]*siz[ne.id]+sum[ne.id]*siz[x];
dp[x]+=dp[ne.id]+sum[x]∗siz[ne.id]+sum[ne.id]∗siz[x];
也就是当前儿子子树的所有情况与其他儿子子树包括父亲本身的情况的运算。
但是要注意要在之后将不能传回去的sum值减掉,也就是儿子边与父亲边相同的情况。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=3e5+5;
struct node
{
int id,col;
bool operator< (const node& a)const
{
return col<a.col;
}
};
vector<node>vec[N];
ll w[N],dp[N],siz[N],sum[N];
void dfs(int x,int fa,int c)
{
siz[x]++;
sum[x]=w[x];
int pre=0;
for(int i=0;i<vec[x].size();i++)
{
node ne=vec[x][i];
if(ne.id==fa)
continue;
if(i>0&&vec[x][i].col!=vec[x][pre].col)
{
for(;pre<i;pre++)
{
if(vec[x][pre].id==fa)
continue;
sum[x]+=sum[vec[x][pre].id]+siz[vec[x][pre].id]*w[x];
siz[x]+=siz[vec[x][pre].id];
}
}
dfs(ne.id,x,vec[x][i].col);
dp[x]+=dp[ne.id]+sum[x]*siz[ne.id]+sum[ne.id]*siz[x];
}
for(int i=0;i<pre;i++)
{
if(vec[x][i].id==fa)
continue;
if(vec[x][i].col==c)
sum[x]-=(sum[vec[x][i].id]+siz[vec[x][i].id]*w[x]),siz[x]-=siz[vec[x][i].id];
}
if(vec[x][pre].col==c)
return ;
//sum[x]+=(siz[x]-1)*w[x];
for(;pre<vec[x].size();pre++)
{
if(vec[x][pre].id==fa)
continue;
sum[x]+=sum[vec[x][pre].id]+siz[vec[x][pre].id]*w[x];
siz[x]+=siz[vec[x][pre].id];
}
}
int main()
{
int n;
while(~scanf("%d",&n))
{
for(int i=1;i<=n;i++)
dp[i]=sum[i]=siz[i]=0;
for(int i=1;i<=n;i++)
vec[i].clear();
int x,y,col;
for(int i=1;i<=n;i++)
scanf("%lld",&w[i]);
for(int i=1;i<n;i++)
scanf("%d%d%d",&x,&y,&col),vec[x].push_back({y,col}),vec[y].push_back({x,col});
for(int i=1;i<=n;i++)
sort(vec[i].begin(),vec[i].end());
dfs(1,0,0);
printf("%lld\n",dp[1]);
}
return 0;
}