-
2 3 1 0 1 1 1 2 4 1 0 1 1 1 2 1 3
样例输出
-
1 0
描述
小Hi最近参加了一场比赛,这场比赛中小Hi被要求将一棵树拆成3份,使得每一份中所有节点的权值和相等。
比赛结束后,小Hi发现虽然大家得到的树几乎一模一样,但是每个人的方法都有所不同。于是小Hi希望知道,对于一棵给定的有根树,在选取其中2个非根节点并将它们与它们的父亲节点分开后,所形成的三棵子树的节点权值之和能够两两相等的方案有多少种。
两种方案被看做不同的方案,当且仅当形成方案的2个节点不完全相同。
输入
每个输入文件包含多组输入,在输入的第一行为一个整数T,表示数据的组数。
每组输入的第一行为一个整数N,表示给出的这棵树的节点数。
接下来N行,依次描述结点1~N,其中第i行为两个整数Vi和Pi,分别描述这个节点的权值和其父亲节点的编号。
父亲节点编号为0的节点为这棵树的根节点。
对于30%的数据,满足3<=N<=100
对于100%的数据,满足3<=N<=100000, |Vi|<=100, T<=10
输出
对于每组输入,输出一行Ans,表示方案的数量。
统计所形成的三棵子树的节点权值之和能够两两相等的方案,等价于在这树上取两个不同且非根结点,形成三棵子树后子树节点权值之和两两相等。
树型dp求解,每个节点维护res(以这个节点为根的子树节点权值和),cnt(以这个节点为根的子树权值等于sum/3的节点个数)。
一开始能够想到的一个A节点res=sum/3,那么在这棵子树外再找一个res=sum/3的B节点进行组合不就得出一种方案了吗。可是这里面是分两类的
1. B不是A的祖先,那么后来枚举B的时候A又被算了一次。记为2*s1
2. B是A的祖先,其实这种情况是错误的,因为A、B分别取出后,A子树res=sum/3,B子树res=0(因为A子树本来就是B的一部分啊),这种方案是错误的要除去,且记为p
还有一种情况是一个节点res=sum*2/3,那么这个节点与其子树内不包括它自己,任意一个res=sum/3的节点相组合就是一种方案,记为s2。
因为不能选root,所以cnt对res[root]=sum/3情况不予考虑
2*s1+p=cnt[root]^2 - (res[x]=sum/3&&x!=root)
P= (cnt[x]>0&&res[x]=sum/3&&x!=root)
S2= (res[x]=sum*2/3&&x!=root)
最后ans=s1+s2
Dp时维护好数据最后求解即可
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+8;
vector<int> g[maxn];
int root;
ll sum;
ll res[maxn];
ll cnt[maxn];
ll a[maxn];
ll s2,s3,sig;
void dfs(int x,int p){
res[x]=a[x];
cnt[x]=0;
if(g[x].size()<=1){
if(res[x]==sum)cnt[x]=1;
sig+=cnt[x];
//cout<<"xx="<<x<<" "<<sum<<" "<<res[x]<<" "<<cnt[x]<<endl;
return ;
}
for(int i=0;i<g[x].size();i++){
int u=g[x][i];
if(u==p)continue;
dfs(u,x);
res[x]+=res[u];
cnt[x]+=cnt[u];
}
if(res[x]==sum&&x!=root)cnt[x]+=1;
if(res[x]==sum&&x!=root)sig+=cnt[x];
if(cnt[x]>0&&res[x]==sum&&x!=root)s3+=(cnt[x]-1);
if(res[x]==2*sum&&x!=root)s2+=(res[x]==sum?cnt[x]-1:cnt[x]);
}
int main()
{
int T;
scanf("%d",&T);
while(T--){
int n;
scanf("%d",&n);
for(int i=0;i<n+7;i++)g[i].clear();
sum=0;
for(int i=1;i<=n;i++){
int v,p;
scanf("%d%d",&v,&p);
a[i]=v;
g[i].push_back(p);
g[p].push_back(i);
if(p==0)root=i;
sum+=v;
}
if(sum%3){printf("0\n");continue;}
sum/=3;
s2=s3=sig=0;
dfs(root,0);
// for(int i=1;i<=n;i++){
// printf("id==%d res==%I64d cnt==%I64d\n",i,res[i],cnt[i]);
// }
// cout<<" s2=="<<s2<<" s3=="<<s3<<" sig=="<<sig<<endl;
ll ans=((cnt[root]*cnt[root]-sig)-s3)/2+s2;
cout<<ans<<endl;
}
return 0;
}