解题思路
点分治。
代码:
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<iostream>
#include<cstring>
#include<string>
#include<cstdlib>
#define ll long long
using namespace std;
int hed[100005],nex[200005],lb[200005],cap[200005];
int f[400005][2],g[400005][2],sum[400005];
bool pa[100005];
int n,lo,u,mmin,mmax;
ll ans;
void add(int x,int y,int num){
lo++;
nex[lo]=hed[x];
hed[x]=lo;
lb[lo]=y;
cap[lo]=num;
}
int getsiz(int x,int fa){
int r=0;
for(int i=hed[x];i!=0;i=nex[i])
if(lb[i]!=fa && !pa[lb[i]])
r+=getsiz(lb[i],x);
return r+1;
}
int getu(int x,int fa,int siz){
int r=0,rc;
bool flag=true;
for(int i=hed[x];i!=0;i=nex[i])
if(lb[i]!=fa && !pa[lb[i]]){
rc=getu(lb[i],x,siz);
if(rc*2>siz) flag=false;
r+=rc;
}
rc=siz-r-1;
if(rc*2>siz) flag=false;
if(flag==true) u=x;
return r+1;
}
void dfs(int x,int fa,int v){
mmax=max(mmax,v);
mmin=min(mmin,v);
if(sum[v]) f[v][1]++;
else f[v][0]++;
sum[v]++;
for(int i=hed[x];i!=0;i=nex[i])
if(lb[i]!=fa && !pa[lb[i]]) dfs(lb[i],x,v+cap[i]);
sum[v]--;
}
void solve(int x){
int r=getsiz(x,x);
getu(x,x,r);
int midmaxx=0,midmin=400000;
g[200000][0]=1;
for(int i=hed[u];i!=0;i=nex[i])
if(!pa[lb[i]]) {
mmax=0;mmin=400000;
dfs(lb[i],u,200000+cap[i]);
midmaxx=max(midmaxx,mmax);
midmin=min(midmin,mmin);
ans+=1ll*f[200000][0]*(g[200000][0]-1);
for(int i=mmin;i<=mmax;i++)
ans+=1ll*f[i][0]*g[400000-i][1]+1ll*f[i][1]*g[400000-i][0]+1ll*f[i][1]*g[400000-i][1];
for(int i=mmin;i<=mmax;i++) {g[i][0]+=f[i][0];g[i][1]+=f[i][1];f[i][0]=f[i][1]=0;}
}
pa[u]=1;g[200000][0]=0;
for(int i=midmin;i<=midmaxx;i++) g[i][0]=g[i][1]=0;
for(int i=hed[u];i!=0;i=nex[i])
if(!pa[lb[i]]) solve(lb[i]);
}
int main(){
int xx,yy,zz;
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%d%d%d",&xx,&yy,&zz);
if(zz==0) zz--;
add(xx,yy,zz);add(yy,xx,zz);
}
solve(1);
printf("%lld",ans);
return 0;
}