题意
题意是在树的根节点有无数个军队。每周可以命令仅一只军队沿一条边移动。文最短多久可以占领整棵树。
题解
这个我是贪心做的。可以发现,如果从某个叶子A走向下一个叶子B,如果lca(A,B)到A的距离比根到lca的距离大的话,就不能从A往回走,要再来一支新的军队。有几个贪心的原则。
- 首先明确的是,当某只军队占领某个叶子结点后,他如果有多种从某个公共祖先去别的叶子结点, 且进入那些叶子结点就回不上去(由于太深)的走法,此时有几种选取下一个叶子的方法且互不相交,应该选哪个呢?应该选公共祖先最远离根的那个。因为如果公共祖先更高就会绕路变多,走的距离变大。
- 要尽可能先走深度浅的。这样更容易走更多的回马枪,避免重新派军队的浪费。根据这个原则每棵子树最后一个占领的结点应该是这棵字树深度最大的点。
- 如果有公共祖先相同的两棵树,现在他们上面都有比较深的叶子,走到一颗子树就不可能回到另一棵树的话,这个实际上随便选一棵树走就行,步数是一样的。
所以直接把每个节点的各个子节点,按最大深度升序排序,完了能得到一个叶子结点序列。只要依次处理,看走到下一个叶子要不要派新军队。
AC代码
我的排序加树剖1980ms卡过了。但网上有线性dp的做法。
#include <bits/stdc++.h>
using namespace std;
const int NN=1000100;
const int oo=2e9+10;
//int sum[NN*4],maxn[NN*4];
//int a[NN],w[NN];
int top[NN],siz[NN],son[NN],fath[NN],deep[NN];
vector <int> con[NN];
int cnt,pos[NN];
// void up(int cur){
// sum[cur]=sum[cur<<1]+sum[cur<<1|1];
// maxn[cur]=max(maxn[cur<<1],maxn[cur<<1|1]);
// }
int n;
int cntleav=0;
int leav[NN];
int minheight[NN];
bool cmp_minheight(int x,int y){
return minheight[x]<minheight[y];
}
void dfs1(int cur,int fa){
deep[cur]=deep[fa]+1;
siz[cur]=1;
son[cur]=0;
fath[cur]=fa;
int maxnn=0;
int numson=con[cur].size();
minheight[cur]=0;
if(numson==0)minheight[cur]=0;
for(int i=0;i<numson;i++){
int nex=con[cur][i];
if(nex!=fa){
dfs1(nex,cur);
minheight[cur]=max(minheight[cur],minheight[nex]+1);
siz[cur]+=siz[nex];
if(siz[nex]>maxnn){
maxnn=siz[nex];
son[cur]=nex;
}
}
}
}
void dfs2(int cur,int fa,int k){
top[cur]=k;
pos[cur]=++cnt;
//a[pos[cur]]=w[cur];
if(son[cur])dfs2(son[cur],cur,k);
int numson=con[cur].size();
for(int i=0;i<numson;i++){
int nex=con[cur][i];
if(nex!=fa&&nex!=son[cur]){
dfs2(nex,cur,nex);
}
}
}
void dfsl(int cur,int fa){
//a[pos[cur]]=w[cur];
int numson=con[cur].size();
if(numson==0)leav[++cntleav]=cur;
for(int i=0;i<numson;i++){
int nex=con[cur][i];
if(nex!=fa){
dfsl(nex,cur);
}
}
}
int lca(int u,int v){
int ans=0;
while(top[u]!=top[v]){
if(deep[top[u]]<deep[top[v]])swap(u,v);
//ans=max(ans,ask_max(1,1,n,pos[top[u]],pos[u]) );
ans+=pos[u]-pos[top[u]]+1;
u=fath[top[u]];
}
if(deep[u]<deep[v])swap(u,v);
return v;
}
int main()
{
int t;scanf("%d",&t);
for(int z=1;z<=t;z++){
int n;scanf("%d",&n);
for(int i=1;i<=n;i++)con[i].clear();
for(int i=2;i<=n;i++){
int x;scanf("%d",&x);
con[x].push_back(i);
}
cnt=0;cntleav=0;
dfs1(1,0);
for(int i=1;i<=n;i++){
//printf("%d\n",minheight[i]);
sort(con[i].begin(),con[i].end(),cmp_minheight);
}
dfsl(1,0);
dfs2(1,0,1);
int neww=1;
long long ans=0;
int lastfa=0;
for(int i=1;i<=cntleav;i++){
//printf("*%d\n",leav[i]);
if(neww==1)ans+=deep[leav[i]]-1;
else ans+=deep[leav[i-1]]+deep[leav[i]]-(deep[lastfa]<<1);
if(i==cntleav)break;
int lcfa=lca(leav[i],leav[i+1]);
if(deep[leav[i]]-deep[lcfa]>deep[lcfa]-1){
neww=1;
}
else{//printf("*%d\n",deep[leav[i]]-deep[lcfa]);
neww=0;
}lastfa=lcfa;
}
printf("Case #%d: ",z);
printf("%lld\n",ans);
}
return 0;
}