Description
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
Input
第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
Output
输出符合采药人要求的路径数目。
Sample Input
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
Sample Output
1
HINT
对于100%的数据,N ≤ 100,000。
果然状态还没恢复好= =这题都写这么久。。
点分治
就是每次找子树的重心然后递归做子树
题解就偷个懒好了
“这样我们枚举根节点的每个子树。用f[i][0...1],g[i][0...1]分别表示前面几个子树以及当前子树和为i的路径数目,0和1用于区分路径上是否存在前缀和为i的节点。那么当前子树的贡献就是f[0][0] * g[0][0] + Σf [i][0] * g [-i][1] + f[i][1] * g[-i][0] + f[i][1] * g[-i][1],其中i的范围[-d,d],d为当前子树的深度。”、
————HZWER
嗯于是我就不写题解了。。
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
using namespace std;
int n;
int sum;
struct line
{
int s,t;
int x;
int next;
}a[200001];
int head[100001];
int edge;
inline void add(int s,int t,int x)
{
a[edge].next=head[s];
head[s]=edge;
a[edge].s=s;
a[edge].t=t;
a[edge].x=x;
}
int d;
int minx=2100000000,mini;
bool v[100001];
int son[100001];
int dis[100001];
inline void find(int d,int fax)
{
son[d]=0;
int tmp=0;
int i;
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(!v[t]&&t!=fax)
{
find(t,d);
son[d]+=son[t]+1;
tmp=max(tmp,son[t]+1);
}
}
tmp=max(tmp,sum-son[d]-1);
if(tmp<minx||tmp==minx&&d<mini)
{
mini=d;
minx=tmp;
}
}
int dep[100001];
int t[100001];
long long f[200001][2],g[200001][2];
int mxdeep;
inline void dfs(int d,int fax)
{
mxdeep=max(mxdeep,dep[d]);
if(t[dis[d]])
f[dis[d]][1]++;
else
f[dis[d]][0]++;
t[dis[d]]++;
int i;
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(t!=fax&&!v[t])
{
dep[t]=dep[d]+1;
dis[t]=dis[d]+a[i].x;
dfs(t,d);
}
}
t[dis[d]]--;
}
long long ans;
inline void solve(int d)
{
v[d]=true;
g[n][0]=1;
int i,j;
int mx=0;
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(!v[t])
{
dep[t]=1;
dis[t]=n+a[i].x;
mxdeep=1;
dfs(t,0);
mx=max(mx,mxdeep);
ans+=(g[n][0]-1)*f[n][0];
for(j=-mxdeep;j<=mxdeep;j++)
ans+=g[n-j][1]*f[n+j][1]+g[n-j][0]*f[n+j][1]+g[n-j][1]*f[n+j][0];
for(j=n-mxdeep;j<=n+mxdeep;j++)
{
g[j][0]+=f[j][0];
g[j][1]+=f[j][1];
f[j][0]=0;
f[j][1]=0;
}
}
}
for(i=n-mx;i<=n+mx;i++)
{
g[i][0]=0;
g[i][1]=0;
}
for(i=head[d];i!=0;i=a[i].next)
{
int t=a[i].t;
if(!v[t])
{
minx=2100000000;
mini=0;
sum=son[d];
find(t,0);
int root=mini;
solve(root);
}
}
}
int main()
{
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
scanf("%d",&n);
int s,t,x;
int i;
for(i=1;i<=n-1;i++)
{
scanf("%d%d%d",&s,&t,&x);
if(x==0)
x--;
edge++;
add(s,t,x);
edge++;
add(t,s,x);
}
sum=n;
find(1,0);
int d=mini;
memset(v,0,sizeof(v));
solve(d);
printf("%lld\n",ans);
return 0;
}