题目描述
NIO is playing a game about trees.
The game has two trees $A, B$, each with $N$ vertices. The vertices in each tree are numbered from $1$ to $N$, and the $i$-th vertex has the weight $v_i$. The root of each tree is vertex 1. Given $K$ key numbers $x_1,\dots,x_K$, find the number of ways to remove exactly one key number so that the weight of the lowest common ancestor, in tree A, of the vertices with the remaining key numbers is greater than the weight of the lowest common ancestor, in tree B, of the vertices with the remaining key numbers.
输入描述:
The first line has two positive integers $N, K$ ($2 \leq K \leq N \leq 10^5$). The second line has $K$ distinct positive integers $x_1,\dots,x_K$ ($x_i \leq N$). The third line has $N$ positive integers $a_i$ ($a_i \leq 10^9$), the weights of the vertices in A. The fourth line has $N-1$ positive integers $\{pa_i\}$: $pa_i$ is the father of vertex $i+1$ in tree A. The fifth line has $N$ positive integers $b_i$ ($b_i \leq 10^9$), the weights of the vertices in B. The sixth line has $N-1$ positive integers $\{pb_i\}$: $pb_i$ is the father of vertex $i+1$ in tree B.
输出描述:
One integer indicating the answer.
示例1
输入
5 3
5 4 3
6 6 3 4 6
1 2 2 4
7 4 5 7 7
1 1 3 2
输出
1
说明
In the first case, the key numbers are 5, 4, 3. Remove key number 5: the lowest common ancestor of the remaining key vertices is 2 in A (weight 6) and 3 in B (weight 5). Remove key number 4: it is 2 in A (weight 6) and 1 in B (weight 7). Remove key number 3: it is 4 in A (weight 4) and 1 in B (weight 7). Only removing key number 5 makes the weight in A greater than the weight in B, so the answer is 1.
示例2
输入
10 3
10 9 8
8 9 9 2 7 9 0 0 7 4
1 1 2 4 3 4 2 4 7
7 7 2 3 4 5 6 1 5 3
1 1 3 1 2 4 7 3 5
输出
2
题意: 给出两棵树A和B,点的编号都是从1~n,每棵树上的点都有各自的价值wi,再给出一个包含k个点的关键点集合,现在可以移除关键点集合中的一个点,使得剩下关键点在A树上的LCA的点权大于在B树上的LCA的点权,求多少种方案符合要求。
分析: 比赛的时候想到的是比较麻烦的模拟做法,根据树的形态来分类讨论,其实这样是非常麻烦的。赛后看到了正解:维护关键点的前缀LCA和后缀LCA,这样就能很快求出剩下点的LCA了,时间复杂度为O(n log n)。具体做法就是先建好两棵树,然后对于A树维护关键点的前缀LCA数组front[i][1](即a[1..i]的LCA),再维护后缀LCA数组back[i][1](即a[i..k]的LCA),对于B树同理得到front[i][2],back[i][2]。然后就可以枚举移除哪个关键点了:移除第i个关键点后,剩下点在A树中的LCA就是lca(front[i-1][1], back[i+1][1]),在B树中的LCA就是lca(front[i-1][2], back[i+1][2])。注意边界的特判:移除第1个点时剩下点的LCA就是back[2],移除第k个点时就是front[k-1]。
具体代码如下:
正解:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <string>
#include <cstring>
#include <vector>
using namespace std;
// g[u][t]: children of vertex u in tree t (t = 1 for tree A, t = 2 for tree B; index 0 unused).
vector<int> g[100005][3];
// n, k: number of vertices / key vertices.
// w[v][t]: weight of vertex v in tree t.
// fa[v][j][t]: 2^j-th ancestor of v in tree t (0 once it goes above the root).
// dep[v][t]: depth of v in tree t (root = 1; vertex 0 keeps depth 0).
// a[1..k]: the key vertices.
int n, k, w[100005][3], fa[100005][21][3], dep[100005][3], a[100005];
int front[100005][3], back[100005][3];// per tree: prefix LCA (front[i] = LCA of a[1..i]) and suffix LCA (back[i] = LCA of a[i..k])
// Fills dep[][type] and the binary-lifting table fa[][][type] for the subtree
// of `now` in tree `type`, treating `pre` as the parent of `now`.
// Iterative on purpose: a path-shaped tree with N = 1e5 vertices would make
// the recursive version 1e5 frames deep and risk a stack overflow.
void dfs(int now, int pre, int type)
{
    dep[now][type] = dep[pre][type] + 1;
    fa[now][0][type] = pre;
    vector<int> stk(1, now);
    while(!stk.empty()){
        int u = stk.back();
        stk.pop_back();
        // dep[u] and fa[u][0] were set when u was discovered, and every ancestor
        // of u already has a complete table, so the jumps can be filled now.
        for(int i = 1; i <= 20; i++)// ancestors beyond the root are vertex 0
            fa[u][i][type] = fa[fa[u][i-1][type]][i-1][type];
        int parent = fa[u][0][type];
        for(size_t i = 0; i < g[u][type].size(); i++){
            int v = g[u][type][i];
            if(v == parent) continue;
            dep[v][type] = dep[u][type] + 1;
            fa[v][0][type] = u;
            stk.push_back(v);
        }
    }
}
// Lowest common ancestor of x and y in tree `type`, by binary lifting:
// lift the deeper vertex up to the other's depth, then lift both together
// while their ancestors still differ. Requires dfs(1, 0, type) to have run.
int lca(int x, int y, int type)
{
    int u = x, v = y;
    if(dep[u][type] < dep[v][type]){
        int t = u;
        u = v;
        v = t;
    }
    const int target = dep[v][type];
    for(int step = 20; step >= 0; --step){
        int up = fa[u][step][type];
        if(dep[up][type] >= target)
            u = up;
    }
    if(u == v)
        return u;
    for(int step = 20; step >= 0; --step){
        int pu = fa[u][step][type];
        int pv = fa[v][step][type];
        if(pu != pv){
            u = pu;
            v = pv;
        }
    }
    return fa[u][0][type];
}
// Reads both trees and the key set, builds per-tree prefix/suffix LCAs of the
// keys, and counts the keys whose removal leaves
// weight_A(LCA of the rest) > weight_B(LCA of the rest).
signed main()
{
    cin >> n >> k;
    for(int i = 1; i <= k; i++) scanf("%d", &a[i]);
    // Input order: weights of A, parents of A, weights of B, parents of B.
    for(int t = 1; t <= 2; t++){
        for(int v = 1; v <= n; v++) scanf("%d", &w[v][t]);
        for(int child = 2; child <= n; child++){
            int parent;
            scanf("%d", &parent);
            g[parent][t].push_back(child);
        }
    }
    for(int t = 1; t <= 2; t++){
        dfs(1, 0, t);
        // Seeding with a[1] / a[k] makes the first fold step lca(x, x) = x.
        front[0][t] = a[1];
        back[k+1][t] = a[k];
        for(int i = 1; i <= k; i++) front[i][t] = lca(front[i-1][t], a[i], t);
        for(int i = k; i >= 1; i--) back[i][t] = lca(back[i+1][t], a[i], t);
    }
    int ans = 0;
    for(int i = 1; i <= k; i++){// try removing key a[i]
        int top[3];
        for(int t = 1; t <= 2; t++){
            if(i == 1) top[t] = back[2][t];             // only the suffix remains
            else if(i == k) top[t] = front[k-1][t];     // only the prefix remains
            else top[t] = lca(front[i-1][t], back[i+1][t], t);
        }
        if(w[top[1]][1] > w[top[2]][2]) ans++;
    }
    printf("%d\n", ans);
    return 0;
}
比赛时的模拟做法:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <string>
#include <vector>
#define inf 0x3f3f3f3f
using namespace std;
// a[1..k]: key vertices; w[v][t]: weight of vertex v in tree t (t = 1 for A, t = 2 for B).
int a[100005], w[100005][3];
// num[v][t]: number of key vertices in the subtree of v in tree t (accumulated by dfs2).
int num[100005][3];
// flag[v]: true iff v is a key vertex.
bool flag[100005];
// dep[v][t]: depth in tree t (root = 1); fa[v][j][t]: 2^j-th ancestor (0 above the root).
int dep[100005][3], fa[100005][21][3];
// F[v][t]: for a key v, the subtree root passed to dfs2 (in main: the child of the
// all-keys LCA whose subtree contains v).
int F[100005][3];
// tr[u][t]: children of vertex u in tree t.
vector<int> tr[100005][3];
int n, k;
// Scratch list of key vertices used when an LCA has to be recomputed.
vector<int> temp;
// Fills dep[][type] and the binary-lifting table fa[][][type] for the subtree
// of `now` in tree `type`, treating `pre` as the parent of `now`.
// Iterative on purpose: a path-shaped tree with N = 1e5 vertices would make
// the recursive version 1e5 frames deep and risk a stack overflow.
void dfs(int now, int pre, int type){
    dep[now][type] = dep[pre][type] + 1;
    fa[now][0][type] = pre;
    vector<int> stk(1, now);
    while(!stk.empty()){
        int u = stk.back();
        stk.pop_back();
        // dep[u] and fa[u][0] were set when u was discovered, and every ancestor
        // of u already has a complete table, so the jumps can be filled now.
        for(int i = 1; i <= 20; i++)// ancestors beyond the root are vertex 0
            fa[u][i][type] = fa[fa[u][i-1][type]][i-1][type];
        int parent = fa[u][0][type];
        for(size_t i = 0; i < tr[u][type].size(); i++){
            int v = tr[u][type][i];
            if(v == parent) continue;
            dep[v][type] = dep[u][type] + 1;
            fa[v][0][type] = u;
            stk.push_back(v);
        }
    }
}
int lca(int x, int y, int type){
if(dep[x][type] < dep[y][type]) swap(x, y);
for(int i = 20; i >= 0; i--)
if(dep[fa[x][i][type]][type] >= dep[y][type])
x = fa[x][i][type];
if(x == y) return x;
for(int i = 20; i >= 0; i--)
if(fa[x][i][type] != fa[y][i][type])
x = fa[x][i][type], y = fa[y][i][type];
return fa[x][0][type];
}
void dfs2(int now, int fa, int type, int super_fa){
if(flag[now]){
num[now][type]++;
F[now][type] = super_fa;
}
for(int i = 0; i < tr[now][type].size(); i++){
int to = tr[now][type][i];
dfs2(to, now, type, super_fa);
num[now][type] += num[to][type];
}
}
// Pre-order walk of the subtree of `now` in tree `type`: appends `now` to the
// global `temp` if it is a key vertex, then visits each child subtree in order.
// NOTE(review): not called anywhere in the visible code of this program.
void dfs3(int now, int type){
    if(flag[now])
        temp.push_back(now);
    const vector<int>& kids = tr[now][type];
    for(vector<int>::const_iterator it = kids.begin(); it != kids.end(); ++it)
        dfs3(*it, type);
}
signed main()
{
scanf("%d%d", &n, &k);
for(int i = 1; i <= k; i++){
scanf("%d", &a[i]);
flag[a[i]] = true;
}
for(int i = 1; i <= n; i++) scanf("%d", &w[i][1]);
for(int i = 2; i <= n; i++){
int to;
scanf("%d", &to);
tr[to][1].push_back(i);
}
for(int i = 1; i <= n; i++) scanf("%d", &w[i][2]);
for(int i = 2; i <= n; i++){
int to;
scanf("%d", &to);
tr[to][2].push_back(i);
}
dfs(1, 0, 1);
dfs(1, 0, 2);
int lca1 = a[1], lca2 = a[1];
for(int i = 2; i <= k; i++){
lca1 = lca(a[i], lca1, 1);
lca2 = lca(a[i], lca2, 2);
}
int cnt1 = 0;
for(int i = 0; i < tr[lca1][1].size(); i++){
int to = tr[lca1][1][i];
dfs2(to, lca1, 1, to);
if(num[to][1]) cnt1++;
}
int cnt2 = 0;
for(int i = 0; i < tr[lca2][2].size(); i++){
int to = tr[lca2][2][i];
dfs2(to, lca2, 2, to);
if(num[to][2]) cnt2++;
}
int ans = 0;
for(int i = 1; i <= k; i++){
int t1 = -1, t2 = -1;
if(cnt1 > 2) t1 = w[lca1][1];
else if(a[i] == lca1){
if(cnt1 == 2) t1 = w[lca1][1];
else if(cnt1 == 1){
temp.clear();
for(int j = 1; j <= k; j++){
if(a[j] == lca1) continue;
temp.push_back(a[j]);
}
int top = temp[0];
for(int j = 0; j < temp.size(); j++)
top = lca(top, temp[j], 1);
t1 = w[top][1];
}
//此时cnt1不可能为0
}
else if(cnt1 == 2){
if(num[F[a[i]][1]][1] > 1) t1 = w[lca1][1];//lca不变
else{
temp.clear();
for(int j = 1; j <= k; j++){
if(a[j] == a[i]) continue;
temp.push_back(a[j]);
}
int top = temp[0];
for(int j = 0; j < temp.size(); j++)
top = lca(top, temp[j], 1);
t1 = w[top][1];
}
}
else if(cnt1 == 1)
t1 = w[lca1][1];
if(cnt2 > 2) t2 = w[lca2][2];
else if(a[i] == lca2){
if(cnt2 == 2) t2 = w[lca2][2];
else if(cnt2 == 1){
temp.clear();
for(int j = 1; j <= k; j++){
if(a[j] == lca2) continue;
temp.push_back(a[j]);
}
int top = temp[0];
for(int j = 0; j < temp.size(); j++)
top = lca(top, temp[j], 2);
t2 = w[top][2];
}
//此时cnt1不可能为0
}
else if(cnt2 == 2){
if(num[F[a[i]][2]][2] > 1) t2 = w[lca2][2];//lca不变
else{
temp.clear();
for(int j = 1; j <= k; j++){
if(a[j] == a[i]) continue;
temp.push_back(a[j]);
}
int top = temp[0];
for(int j = 0; j < temp.size(); j++)
top = lca(top, temp[j], 2);
t2 = w[top][2];
}
}
else if(cnt2 == 1)
t2 = w[lca2][2];
if(t1 > t2) ans++;
}
printf("%d\n", ans);
return 0;
}