题面
有两棵 n n n 个点的有根树 T 1 T_1 T1, T 2 T_2 T2,根是 1 1 1 ,共用编号 1 1 1~ n n n。求最大的点集 S S S 满足每个点在 T 1 T_1 T1 中一条到根的链上,且任意两个点在 T 2 T_2 T2 中没有祖先关系。
输出该点集的大小, t t t 组数据。
1 ≤ t ≤ 3 × 1 0 5 , 2 ≤ n ≤ 3 × 1 0 5 , ∑ n ≤ 3 × 1 0 5 1\leq t\leq3\times10^5~,~2\leq n\leq 3\times 10^5~,~\sum n\leq 3\times 10^5 1≤t≤3×105 , 2≤n≤3×105 , ∑n≤3×105
题解
题面我已经转化了一部分了,继续转化:
我们处理出每个点在 T 2 T_2 T2 中的 d f s \rm dfs dfs 序,记为 d f n [ i ] {\rm dfn}[i] dfn[i],并求出每个点在 T 2 T_2 T2 的子树内节点(包括自己) d f n [ j ] {\rm dfn}[j] dfn[j] 的范围,记为 [ l i , r i ] [l_i,r_i] [li,ri] ,不妨称它为点 i i i 的区间。
那么一个集合 S S S 可行,等价于集合中的点在 T 1 T_1 T1 一条到跟的链上,且集合中每个点的区间两两无交集。
我们还可以发现一些性质:每两个点的区间要么无交集,要么存在包含关系,且一个左端点只对应一个右端点。那么我们就可以想出一个贪心策略:互相包含的区间,取较小的保留。把它挪回树上,其实就相当于存在一个点的子孙可选的时候该点不选最优一样。
接下来就简单了。我们只需要从 T 1 T_1 T1 的根开始往下 d f s \rm dfs dfs ,每到一个点往某数据结构中加入自己的区间:
- 如果被数据结构中原有的更大区间包含了,那么把那个大区间删掉,把自己加进去。
- 如果包含了数据结构中原有的至少一个小区间,那么不加自己。
- 否则,把自己加入进去,此时数据结构大小(即答案)+1。
然后遍历儿子。
回溯的时候数据结构要退回来时的状态。
具体用的数据结构就五花八门了,比较懒的可以直接用 set
之类,稍微比较聪明的可以打打树状数组或常数优秀的 zkw
线段树(有的人利用了两棵树中父亲编号小于儿子编号的特性,打出最优秀的三行朴素树状数组跑了
R
a
n
k
1
\rm Rank~1
Rank 1)。
CODE
下面是个比较好看懂的 set
代码
#include<set>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 300005
#define ENDL putchar('\n')
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
#define SI set<int>::iterator
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
int n,m,i,j,s,o,k;
vector<int> g0[MAXN];
int L[MAXN],R[MAXN],lR[MAXN],tim;
void dfs0(int x,int ff) {
L[x] = ++ tim;
for(int i = 0;i < (int)g0[x].size();i ++) {
if(g0[x][i] != ff) dfs0(g0[x][i],x);
}R[x] = tim; lR[L[x]] = R[x];
return ;
}
vector<int> g[MAXN];
int d[MAXN],dfn[MAXN],rr[MAXN],cnt,ans;
set<int> st;
void dfs(int x,int ff) {
int ad = 0;
if(st.empty()) st.insert(L[x]);
else {
SI i = st.lower_bound(L[x]);
if(i != st.begin()) {
i --;
if(lR[*i] >= R[x]) ad = *i,st.erase(ad),st.insert(L[x]);
else {
i ++;
if(i == st.end() || *i > R[x]) st.insert(L[x]);
}
}
else if(i == st.end() || *i > R[x]) st.insert(L[x]);
}
ans = max(ans,(int)st.size());
for(int i = 0;i < (int)g[x].size();i ++) {
if(g[x][i] != ff)
dfs(g[x][i],x);
}
if(st.find(L[x]) != st.end()) st.erase(L[x]);
if(ad) st.insert(ad);
return ;
}
int main() {
int T = read();
while(T --) {
n = read();
tim = 0; cnt = 0;
st.clear();
for(int i = 1;i <= n;i ++) {
g0[i].clear();g[i].clear();lR[i] = 0;
}
for(int i = 2;i <= n;i ++) {
s = read(); g[s].push_back(i);
}
for(int i = 2;i <= n;i ++) {
s = read(); g0[s].push_back(i);
}
dfs0(1,0);ans = 0;dfs(1,0);
printf("%d\n",ans);
}
return 0;
}