题目大意
给定一颗有n个结点的树,每条边的权值为1或-1。问有多少点对(i,j)(注意点对不存在顺序性),满足i到j的最短路径上能找到一个点k,使得i到k的最短路径权值和为0,k到j的最短路径权值和为0。
n<=100000。
点分治
我们进行点分治。
对于当前的根x,我们统计有多少对经过了x的满足题目要求。
我们可以处理出d[i]表示i到x的权值和,b[i]=1表示i到x路径上可以找到异与i与x的一点k满足d[k]=d[i]。
注意到如果i与j是合法的,d[i]+d[j]=0。
因此我们可以用桶进行统计。
所有b[i]=1且d[i]=0的点都可以与x凑成“合法”点对。
所有d[i]=0的点都可以成为“合法”点对。
注意可能会有出自同一子树的点对满足条件,因此分治下去是减去非法方案。
参考程序
#include<cstdio>
#include<algorithm>
#include<deque>
#include<stack>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const ll maxn=100000*2+10;
deque<ll> dl;
stack<ll> sta;
ll a[maxn],sum[maxn],d[maxn],h[maxn],go[maxn],next[maxn],dis[maxn],s[maxn],pd[maxn],fa[maxn],fi[maxn];
bool b[maxn],bz[maxn],vis[maxn],czy;
ll i,j,k,l,t,n,m,tot,top,ans;
void add(ll x,ll y,ll z){
go[++tot]=y;
dis[tot]=z;
next[tot]=h[x];
h[x]=tot;
}
void dg(ll x){
dl.push_back(x);
vis[x]=1;
a[top=1]=x;
ll now,t;
while (!dl.empty()){
t=h[now=dl.front()];
sta.push(now);
s[now]=1;
dl.pop_front();
while (t){
if (!bz[go[t]]&&!vis[go[t]]){
vis[go[t]]=1;
a[++top]=go[t];
fa[go[t]]=now;
dl.push_back(go[t]);
}
t=next[t];
}
}
while (!sta.empty()){
s[fa[sta.top()]]+=s[sta.top()];
sta.pop();
}
}
void dfs(ll x){
ll i,now;
fo(i,1,top) fi[a[i]]=h[a[i]];
sta.push(x);
vis[x]=1;
while (!sta.empty()){
now=sta.top();
while (fi[now]&&(vis[go[fi[now]]]||bz[go[fi[now]]])) fi[now]=next[fi[now]];
if (!fi[now]){
sta.pop();
if (!sta.empty())pd[d[now]]--;
continue;
}
vis[go[fi[now]]]=1;
d[go[fi[now]]]=d[now]+dis[fi[now]];
b[go[fi[now]]]=0;
if (pd[d[go[fi[now]]]]) b[go[fi[now]]]=1;
pd[d[go[fi[now]]]]++;
sta.push(go[fi[now]]);
}
}
void count(ll sig){
ll i;
fo(i,1,top)
if (b[a[i]]&&d[a[i]]>0) sum[d[a[i]]]++;
fo(i,1,top)
if (d[a[i]]<0) ans=ans+sig*sum[-d[a[i]]];
fo(i,1,top){
if (b[a[i]]&&d[a[i]]>0) sum[d[a[i]]]--;
if (b[a[i]]&&d[a[i]]<0) sum[-d[a[i]]]++;
}
fo(i,1,top)
if (!b[a[i]]&&d[a[i]]>0) ans=ans+sig*sum[d[a[i]]];
fo(i,1,top)
if (b[a[i]]&&d[a[i]]<0) sum[-d[a[i]]]--;
ll cnt=0;
fo(i,1,top)
if (!d[a[i]]) cnt++;
if (cnt){
if (sig==1) cnt--;
ans=ans+sig*cnt*(cnt-1)/2;
}
}
void solve(ll x){
top=0;
dg(x);
ll i,j=x,k=0,t;
fo(i,1,top) vis[a[i]]=0;
if (czy) count(-1);else czy=1;
while (1){
t=h[j];
while (t){
if (!bz[go[t]]&&go[t]!=k&&s[go[t]]>s[x]/2){
k=j;
j=go[t];
break;
}
t=next[t];
}
if (!t) break;
}
d[j]=0;
b[j]=0;
dfs(j);
fo(i,1,top) vis[a[i]]=0;
count(1);
fo(i,1,top)
if (!d[a[i]]&&b[a[i]]) ans++;
t=h[j];
bz[j]=1;
//printf("%lld %lld %lld\n",x,j,ans);
while (t){
if (!bz[go[t]]) solve(go[t]);
t=next[t];
}
}
int main(){
//freopen("D:/in.txt","r",stdin);
scanf("%lld",&n);
fo(i,1,n-1){
scanf("%lld%lld%lld",&j,&k,&t);
if (!t) t=-1;
add(j,k,t);
add(k,j,t);
}
//printf("\n");
solve(1);
//printf("\n");
printf("%lld\n",ans);
}