题目地址【IN】
写了一遍就过了,ヾ(◍°∇°◍)ノ゙凯森
题意简述
给你一棵 n n n个节点的树,每条边有边权,定义 d i s ( i , j ) dis(i,j) dis(i,j)为点 i i i到 j j j的距离模 2 2 2后的值,问你有多少个三元组 ( i , j , k ) (i,j,k) (i,j,k)满足 1 ≤ i , j , k ≤ n 1\leq i,j,k\leq n 1≤i,j,k≤n且 d i s ( i , j ) = d i s ( j , k ) = d i s ( i , k ) dis(i,j)=dis(j,k)=dis(i,k) dis(i,j)=dis(j,k)=dis(i,k)。
数据范围 n ≤ 1 e 4 , 0 ≤ v a l ≤ 233 n\leq 1e4,0\leq val\leq 233 n≤1e4,0≤val≤233
我们可以发现 ( a + b ) % 2 = ( a % 2 x o r b % 2 ) (a+b)\%2=(a\%2\ xor\ b\%2) (a+b)%2=(a%2 xor b%2)
那么 d i s ( i , j ) dis(i,j) dis(i,j)就变成了 i → j i\rightarrow j i→j的路径权值异或和。
那么,由于 a x o r a = 0 a\ xor\ a=0 a xor a=0,异或是可以相消的,所以我们可以 d f s dfs dfs处理一个数组 d [ i ] d[i] d[i],表示节点 i i i到根(这里默认根为 1 1 1号点)的路径上的异或和,那么 d i s ( i , j ) = d [ i ] x o r d [ j ] dis(i,j)=d[i]\ xor\ d[j] dis(i,j)=d[i] xor d[j],原因如下图:
6 → 1 6\rightarrow 1 6→1的路径上有 1 − 2 , 2 − 3 , 3 − 4 , 4 − 6 1-2,2-3,3-4,4-6 1−2,2−3,3−4,4−6
7 → 1 7\rightarrow 1 7→1的路径上有 1 − 2 , 2 − 3 , 3 − 5 , 5 − 7 1-2,2-3,3-5,5-7 1−2,2−3,3−5,5−7
那么将 d [ 6 ] x o r d [ 7 ] d[6]\ xor\ d[7] d[6] xor d[7],其中重复的边 1 − 2 , 2 − 3 1-2,2-3 1−2,2−3就会被异或掉,然后剩下的就变成了 6 − 4 , 4 − 3 , 3 − 5 , 5 − 7 = d i s ( 6 , 7 ) 6-4,4-3,3-5,5-7=dis(6,7) 6−4,4−3,3−5,5−7=dis(6,7)了。
所以 d i s ( i , j ) = d [ i ] x o r d [ j ] dis(i,j)=d[i]\ xor\ d[j] dis(i,j)=d[i] xor d[j]
那么由于只有 d [ i ] = d [ j ] = d [ k ] d[i]=d[j]=d[k] d[i]=d[j]=d[k]的时候才会成为一个合法的三元组,而 d [ i ] d[i] d[i]的值只有 0 , 1 0,1 0,1,所以统计一下个数,我们令 c n t 0 , c n t 1 cnt_0,cnt_1 cnt0,cnt1分别为 0 , 1 0,1 0,1的个数,那么答案就为 ( c n t 0 ) 3 + ( c n t 1 ) 3 (cnt_0)^3+(cnt_1)^3 (cnt0)3+(cnt1)3(由于 i , j , k i,j,k i,j,k可以交换顺序,所以方案数为个数的三次方)。
代码非常简单:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int M=1e5+10;
struct ss{
int to,last,len;
ss(){}
ss(int a,int b,int c):to(a),last(b),len(c){}
}g[M<<1];
int head[M],cnt;
void add(int a,int b,int c){
g[++cnt]=ss(b,head[a],c&1);head[a]=cnt;
g[++cnt]=ss(a,head[b],c&1);head[b]=cnt;
}
ll rec[2];
void dfs(int a,int b,int val){
++rec[val];
for(int i=head[a];i;i=g[i].last){
if(g[i].to==b) continue;
dfs(g[i].to,a,val^g[i].len);
}
}
ll cude(ll a){return a*a*a;}
int a,b,c,n;
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%d%d%d",&a,&b,&c);
add(a,b,c);
}
dfs(1,0,0);
printf("%lld\n",cude(rec[0])+cude(rec[1]));
return 0;
}