题目大意
一棵包含
n
个节点的树,每条边都有类型
1≤n≤100000
题目分析
两类边,要求总数相等,首先应该想到将边的两种反映在权值上,分别为
+1
,
−1
。那么边数相同即和为
0
。满足条件的路径一定是两条权值为
这种路径求解问题的一般思路是点分治(你打边分治我不拦你),这题也不例外。
我们规定
dist(x,y)
为
x
,
我们考虑当前重心点
x
(注意,
为了避免重复,我们按子树顺序枚举路径的一个端点,从已处理的子树中找出能构成满足条件路径的点。
我们设
fird
为已处理的子树中满足:
dist(x,y)=d且∀f为y的祖先,dist(x,f)≠d
的
y
的个数。设
设当前枚举的结束点为
如果
e
不存在祖先
当然具体细节由读者自己讨论,需要注意重心为中间点的情况。
时间复杂度
O(nlogn2)
。
代码实现
#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch))
{
if (ch=='-')
f=-1;
ch=getchar();
}
while (isdigit(ch))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int N=100000;
const int M=N-1;
const int E=M<<1;
int last[N+1],fa[N+1],dis[N+1],size[N+1];
int tov[E+1],next[E+1],len[E+1];
int fir[N*2+1],sec[N*2+1];
bool unable[N+1];
long long ans;
int tot,n;
void insert(int x,int y,int z)
{
tov[++tot]=y;
len[tot]=z;
next[tot]=last[x];
last[x]=tot;
}
int que[N+1],head,tail;
int search_core(int u)
{
head=0,tail=1;
que[1]=u;
int x;
while (head!=tail)
{
x=que[++head];
size[x]=1;
int i=last[x],y;
while (i)
{
y=tov[i];
if (y!=fa[x]&&!unable[y])
{
fa[y]=x;
que[++tail]=y;
}
i=next[i];
}
}
for (head=tail;head>1;head--)
size[fa[que[head]]]+=size[que[head]];
int core=0,csize=n+1;
for (head=1;head<=tail;head++)
{
x=que[head];
int i=last[x],y;
int maxs=0;
while (i)
{
y=tov[i];
if (y!=fa[x]&&!unable[y])
maxs=max(maxs,size[y]);
i=next[i];
}
maxs=max(size[u]-size[x],maxs);
if (maxs<csize)
{
csize=maxs;
core=x;
}
}
return core;
}
int extra[2][N*2+1];
bool exist[N*2+1];
int dfs(int x)
{
int ret=1,i=last[x],y;
extra[exist[dis[x]+n]][dis[x]+n]++;
if (exist[dis[x]+n]||!exist[dis[x]+n]&&!dis[x])
ans+=fir[-dis[x]+n]+sec[-dis[x]+n];
else
ans+=sec[-dis[x]+n];
if (!dis[x]&&exist[n])
ans++;
bool rec=exist[dis[x]+n];
exist[dis[x]+n]=true;
while (i)
{
y=tov[i];
if (!unable[y]&&y!=fa[x])
{
fa[y]=x;
dis[y]=dis[x]+len[i];
ret=max(ret,dfs(y)+1);
}
i=next[i];
}
exist[dis[x]+n]=rec;
return ret;
}
void calc(int x)
{
x=search_core(x);
int i=last[x],y,maxs=0;
while (i)
{
y=tov[i];
if (!unable[y])
{
fa[y]=x;
dis[y]=len[i];
int deep=dfs(y);
for (int j=n-deep;j<=n+deep;j++)
{
fir[j]+=extra[0][j];
sec[j]+=extra[1][j];
extra[0][j]=extra[1][j]=0;
exist[j]=false;
}
maxs=max(maxs,deep);
}
i=next[i];
}
for (i=n-maxs;i<=n+maxs;i++)
fir[i]=sec[i]=0;
unable[x]=true;
i=last[x];
while (i)
{
y=tov[i];
if (!unable[y])
calc(y);
i=next[i];
}
}
int main()
{
freopen("yinyang.in","r",stdin);
freopen("yinyang.out","w",stdout);
n=read();
for (int i=1,x,y;i<n;i++)
{
x=read(),y=read();
bool z=read();
insert(x,y,z?1:-1);
insert(y,x,z?1:-1);
}
calc(1);
printf("%lld\n",ans);
fclose(stdin);
fclose(stdout);
return 0;
}