链接
http://www.lydsy.com/JudgeOnline/problem.php?id=3697
题解
这是个思路题,对我这样的
zz
来说可能已经接近自己想出来的极限了。
一看统计符合条件的路径条数,肯定是点分治,而且肯定是静态的。
首先把边权变成
1
和
根据点分治的过程,可知重心和路径是一对多的关系,而且一条路径只会对应一个重心,就是说每条路径都只会在搞某一个重心时被统计,而且仅会被统计一次。
考虑某条路径被统计的时候的情形。
它肯定会被分成两段,假设这两段是
题目要求必须找到一个断点,使得断开后两段的权值和都为
0
。
考虑断点的位置,要么在
为了不重复统计,我们先统计在
用一个
以上只是算了断点不在重心的。
最后再加上断点在重心的,这个可以加入一棵子树的时候搞,如果当前点不具有
good
属性,就把
f[0][0]
计入答案。
打完之后交上去发现
WrongAnswer
了,写了个对拍,发现还存在一个端点在重心,另一个端点在子树中的情况没计入。这个好说只要做完一个重心之后把
f[1][0]
计入答案就好了。
代码
//点分治
#include <cstdio>
#include <algorithm>
#define maxn 200010
#define ll long long
#define forp for(ll p=head[pos];p;p=nex[p])if(to[p]^pre and !grey[to[p]])
using namespace std;
ll dist[maxn], deep[maxn], N, size[maxn], G, head[maxn], to[maxn], w[maxn],
nex[maxn], tot, list[maxn], cnt[maxn], f[2][maxn];
bool grey[maxn], good[maxn];
ll ans, sumG;
inline void adde(ll a, ll b, ll v)
{to[++tot]=b;w[tot]=v;nex[tot]=head[a];head[a]=tot;}
inline ll read(ll x=0)
{
char c=getchar();
while(c<48 or c>57)c=getchar();
while(c>=48 and c<=57)x=(x<<1)+(x<<3)+c-48, c=getchar();
return x;
}
ll dfs(ll pos, ll pre)
{
list[++*list]=pos;
size[pos]=1;
good[pos]=(bool)cnt[-dist[pos]+N];
cnt[-dist[pos]+N]++;
forp dist[to[p]]=dist[pos]+w[p], deep[to[p]]=deep[pos]+1,
size[pos]+=dfs(to[p],pos);
cnt[-dist[pos]+N]--;
return size[pos];
}
void findG(ll pos, ll pre, ll sum)
{
if(sum<sumG)G=pos, sumG=sum;
forp findG(to[p],pos,sum+*size-(size[to[p]]<<1));
}
void solve(ll pos)
{
ll i, p, t;
*list=0, deep[pos]=dist[pos]=0, *size=dfs(pos,-1);
for(i=1,sumG=0;i<=*list;i++)sumG+=deep[list[i]];
findG(G=pos,-1,sumG);
grey[G]=1;
for(p=head[G];p;p=nex[p])
if(!grey[to[p]])
{
*list=0, dist[to[p]]=w[p], dfs(to[p],G);
for(i=1;i<=*list;i++)
{
if(good[list[i]])ans+=f[0][-dist[list[i]]+N]+f[1][-dist[list[i]]+N];
else ans+=f[1][-dist[list[i]]+N];
if(!good[list[i]] and !dist[list[i]])ans+=f[0][N];
}
for(i=1;i<=*list;i++)f[good[list[i]]][dist[list[i]]+N]++;
}
ans+=f[1][N];
for(p=head[G];p;p=nex[p])
if(!grey[to[p]])
{
*list=0, dist[to[p]]=w[p], dfs(to[p],G);
for(i=1;i<=*list;i++)f[good[list[i]]][dist[list[i]]+N]--;
}
for(p=head[G];p;p=nex[p])if(!grey[to[p]])solve(to[p]);
}
void init()
{
ll a, b, v, i;
N=read();
for(i=1;i<N;i++)a=read(), b=read(), v=read()?1:-1, adde(a,b,v), adde(b,a,v);
}
int main()
{
init();
solve(1);
printf("%lld",ans);
return 0;
}