题目链接: http://codeforces.com/contest/1156/problem/D
题意:
给你一棵树,树的边权为 0 0 0 或者 1 1 1 ,定义一个 p a i r < a , b > pair<a,b> pair<a,b> 对答案的贡献 + 1 +1 +1, 当且仅当树上从 a a a 到 b b b 的路径上不存在经过了边权 1 1 1 之后又经过边权 0 0 0 的情况。一对数 < a , b > <a,b> <a,b> 和 < b , a > <b,a> <b,a> 不相同。
做法:
很容易想到树形dp ,但是状态的转移让我想了很久…(毕竟对树形dp不太熟练)。然后又看到了大佬的神奇的并查集做法。
我们把
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 的值定义为,到达点
i
i
i 的时候处于状态
j
j
j 的方案数。
状态
0
−
>
0 ->
0−> 到达这个点的时候子树路径中全部为连续
0
0
0 的方案数。
状态
1
−
>
1 ->
1−> 到达这个点的时候子树路径中先经过了连续
0
0
0 后经过连续
1
1
1 的方案数。
状态
2
−
>
2 ->
2−> 到达这个点的时候子树路径中先经过了连续
1
1
1 后经过连续
0
0
0 的方案数。
状态
3
−
>
3 ->
3−> 到达这个点的时候子树路径中全部为连续
1
1
1 的方案数。
我们将对当前节点 u u u 对答案贡献的值分两个部分计算,一个是经过点 u u u 的,另一部分是以 u u u 为端点的。
我们先说经过点
u
u
u 的。
那么假如我们现在的这条边权为
0
0
0 ,那么现在得到的
v
v
v的值里状态
0
0
0可以用来更新
u
u
u 的0,状态
23
23
23 可以用来更新
u
u
u 的
1
1
1。(至于为什么可以想象一下)。 边权为
1
1
1 的话可以类似得到值。每次找完一个儿子之后就计算一次,这样就可以不漏掉所有的情况。
以 u u u 为端点的情况就可以在全部做完之后对其进行统计。(注意状态 1 1 1和 2 2 2 都是要加的)。
并查集做法非常的有意思。把和该点全 0 0 0 相连的点的 s i z e size size 乘上全 1 1 1 的,再减一,(大概知道这样的话就是把从 0 0 0 到 1 1 1 的所有情况都考虑了并且减掉自己到自己的情况), 1 1 1到 1 1 1和 0 0 0到 0 0 0的情况在端点可以被全部统计进去。
代码
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++)
#define pb push_back
#define fi first
#define se second
#define rep_e(i,u) for(int i=head[u];~i;i=nex[i])
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int maxn=200005;
const int maxm=400005;
const int MAX = 1e9+7;
int n,head[maxn],to[maxm];
int w[maxm],nex[maxm],cnt;
ll ans,dp[maxn][4];
void add(int u,int v,int va){
to[cnt]=v;nex[cnt]=head[u];
w[cnt]=va; head[u]=cnt++;
}
void dfs(int u,int f){
ll tmp[4];
rep_e(i,u){
int v=to[i],x=w[i];
if(v==f) continue;
dfs(v,u);
memset(tmp,0,sizeof(tmp));
if(x==1){
tmp[1]+=dp[v][0]+dp[v][1];
tmp[3]+=dp[v][3]+1;
}
else{
tmp[0]+=dp[v][0]+1;
tmp[2]+=dp[v][3]+dp[v][2];
}
ans+=dp[u][0]*(2*tmp[0]+tmp[2]+tmp[3]);
ans+=dp[u][1]*tmp[3];
ans+=dp[u][2]*tmp[0];
ans+=dp[u][3]*(2*tmp[3]+tmp[0]+tmp[1]);
rep(j,0,3) dp[u][j]+=tmp[j];
}
ans+=2*(dp[u][0]+dp[u][3])+dp[u][1]+dp[u][2];
}
int main() {
memset(head,-1,sizeof(head));
scanf("%d",&n);
rep(i,1,n-1){
int u,v,x;
scanf("%d%d%d",&u,&v,&x);
add(u,v,x); add(v,u,x);
}
dfs(1,-1);
printf("%lld\n",ans);
return 0;
}