题目大意: 有一棵大小为 n n n 的树,枚举每一条边,求出删掉这条边之后两棵子树的重心的编号之和。
题解
考虑求出每个点作为重心的次数,先将原树的重心作为根节点。
先求出每个节点的子树大小(记为 s [ x ] s[x] s[x]),然后设 g [ x ] g[x] g[x] 表示 x x x 的儿子中 s s s 的最大值。
考虑非根节点的点。对于点
x
x
x,设
S
S
S 表示删掉一条边后
x
x
x 不在的那棵树的大小,如果删掉一条边之后能使
x
x
x 变成重心,需要满足这两个不等式:
{
2
×
g
[
x
]
≤
n
−
S
2
×
(
n
−
S
−
s
[
x
]
)
≤
n
−
S
\begin{cases} 2\times g[x]\leq n-S\\ 2\times (n-S-s[x])\leq n-S \end{cases}
{2×g[x]≤n−S2×(n−S−s[x])≤n−S
即需要满足
x
x
x 的任意一个子树大小都不超过
⌊
n
−
S
2
⌋
\lfloor\dfrac {n-S} 2 \rfloor
⌊2n−S⌋,解出来就是:
n
−
2
s
[
x
]
≤
S
≤
n
−
2
g
[
x
]
n-2s[x]\leq S\leq n-2g[x]\\
n−2s[x]≤S≤n−2g[x]
上面这个东西用权值树状数组维护即可,每次求一下在该区间内有多少个 S S S。
以及还有一个限制,就是删掉的边显然不能在 x x x 的子树内,所以要再用另外一个树状数组维护一下子树内的贡献。
然后考虑根节点的贡献,设
a
,
b
a,b
a,b 表示根节点的最大和次大子树的大小,因为需要满足删掉一条边后根节点的任意子树大小都不超过树的大小的一半
,于是分类讨论一下:
- 假如删掉了最大子树内的一条边。需要满足 2 b ≤ n − S 2b\leq n-S 2b≤n−S,即 S ≤ n − 2 b S\leq n-2b S≤n−2b。
- 假如删掉了次大子树内的一条边,需要满足 2 a ≤ n − S 2a\leq n-S 2a≤n−S,即 S ≤ n − 2 a S\leq n-2a S≤n−2a。
这个在dfs时顺便求一下即可。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 300010
#define MS(f) memset(f,0,sizeof(f))
int T,n,a,b;long long ans;
struct edge{int y,next;}e[maxn<<1];
int first[maxn],len;
void buildroad(int x,int y){e[++len]=(edge){y,first[x]};first[x]=len;}
int s[maxn],g[maxn],rt;
void dfs1(int x,int fa)
{
s[x]=1;g[x]=0;bool v=true;
for(int i=first[x];i;i=e[i].next)
{
int y=e[i].y;if(y==fa)continue;
dfs1(y,x);s[x]+=s[y];g[x]=max(g[x],s[y]);
if(s[y]>(n>>1))v=false;
}
if(n-s[x]>(n>>1))v=false; if(v)rt=x;
}
int tr1[maxn],tr2[maxn];
//tr1记录全局范围内的每个S的出现次数,tr2记录遍历过的所有边对应的S
void add(int *tr,int x,int y){for(x+=2;x<=n+2;x+=(x&-x))tr[x]+=y;}
int sum(int *tr,int x){int re=0;for(x+=2;x;x-=(x&-x))re+=tr[x];return re;}
int getsum(int *tr,int x,int y){return sum(tr,y)-sum(tr,x-1);}
bool Big[maxn];
void dfs2(int x,int fa)
{
add(tr1,s[fa],-1);add(tr1,n-s[x],1);Big[x]|=Big[fa];
//注意这里不加1ll*会WA
if(x!=rt)ans+=1ll*x*(getsum(tr1,n-2*s[x],n-2*g[x])+getsum(tr2,n-2*s[x],n-2*g[x]));
if(x!=rt)ans+=1ll*rt*(s[x]<=n-2*(Big[x]?b:a));
add(tr2,s[x],1);
for(int i=first[x];i;i=e[i].next)if(e[i].y!=fa)dfs2(e[i].y,x);
if(x!=rt)ans-=1ll*x*getsum(tr2,n-2*s[x],n-2*g[x]);
//上面加上进来时的tr2,这里减去出去时的tr2,实际上减去的就是子树内的S贡献
add(tr1,s[fa],1);add(tr1,n-s[x],-1);
}
int main()
{
scanf("%d",&T);while(T--)
{
scanf("%d",&n);MS(first);len=ans=0;
for(int i=1,x,y;i<n;i++)scanf("%d %d",&x,&y),
buildroad(x,y),buildroad(y,x); dfs1(1,0);dfs1(rt,0);
MS(tr1);MS(tr2);for(int i=1;i<=n;i++)add(tr1,s[i],1);
a=b=0;for(int i=first[rt];i;i=e[i].next)
if(s[e[i].y]>=a)b=a,a=s[e[i].y];else if(s[e[i].y]>=b)b=s[e[i].y];
MS(Big);for(int i=first[rt];i;i=e[i].next)Big[e[i].y]=(s[e[i].y]==a);
dfs2(rt,0);printf("%lld\n",ans);
}
}