本来做的是农林的题,那题直接dsu就过了。
https://ac.nowcoder.com/acm/contest/7872/J
原意就是用来练dsu on tree的,牛客数据水不会有mle的问题。
思路:结合前面来说,我们来具体考虑暴力怎么做。
由于题目给出x,y不能互相为lca,这个条件如果直接去实现会比较困难。
转化一下,也就是说考虑每个点作为lca的贡献。
还是这张图,考虑如何暴力统计以1作为lca贡献,也就是跑一遍1的每个子树2,5,8。在跑的过程中更新map里面(map存一个数值出现的次数)次数,然后计算。
在暴力的情况下,先跑2这个子树,那么ans+=0;然后把2子树里面出现的次数放到map里面。然后跑5这个子树,ans+=map[a[1]*2-a[u]],就统计下来5号子树和2号子树里面满足题意的对数。
同理更新map,再跑8这个子树。
可以看到这样是暴力n^2的,所以上启发式合并。同时怕卡常数上了unordered_map;
(事后发现牛客的数据水不会MLE
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<unordered_map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
const int maxn=1e5+100;
typedef int LL;
LL a[maxn],ans=0;
LL siz[maxn],son[maxn],flag;
unordered_map<LL,LL>mp;
vector<LL>g[maxn];
void predfs(LL u,LL fa)
{
siz[u]=1;
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i];
if(v==fa) continue;
predfs(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]){
son[u]=v;
}
}
}
void cal(LL u,LL fa,LL lca)///枚举每一个点作为lca时候的贡献
{
ans+=mp[a[lca]*2-a[u]];
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i];
if(v==fa||v==flag) continue;
cal(v,u,lca);
}
}
void add(LL u,LL fa,LL val){
mp[a[u]]+=val;
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i];
if(v==fa||v==flag) continue;
add(v,u,val);
}
}
void dfs(LL u,LL fa,bool keep)
{
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i];
if(v==fa||v==son[u]) continue;
dfs(v,u,0);
}
if(son[u]){
dfs(son[u],u,1);
flag=son[u];
}
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i];
if(v==fa||v==son[u]) continue;
cal(v,u,u);
add(v,u,1);
}
mp[a[u]]++;///本身单独节点,如样例
flag=0;
if(keep==0){
add(u,fa,-1);
}
}
int main(void)
{
cin.tie(0);std::ios::sync_with_stdio(false);
LL n;cin>>n;
for(LL i=1;i<=n;i++){
cin>>a[i];
}
LL m=n-1;
while(m--){
LL u,v;cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
predfs(1,-1);
dfs(1,-1,0);
cout<<ans*2<<endl;
return 0;
}
然后上知乎搜了一下找到了原题,是计蒜客的。https://nanti.jisuanke.com/t/42586
然后把农林的题解发上去,MLE了。然后把牛客过的发上去,基本mle。
看了看vj其他人过的,似乎是个动态开点线段树然后启发式合并。emm目前不是很能补。
再附一个计蒜客问答里面老师给的题解。
#include <cstdio>
#include <iostream>
#include <vector>
#include <ctime>
#define N 110000
using namespace std;
int n, k, p, deep[N], son[N], v[N], size[N];
vector<int> g[N];
long long ans = 0;
struct treap {
struct node {
node *left, *right;
int value, fix, size;
node() {}
node(int _value) {
value = _value;
fix = rand();
size = 1;
left = right = NULL;
}
};
node* root;
void left_rotate(node *&a) {
node *b = a -> right;
a -> right = b -> left;
update(a);
b -> left = a;
update(b);
a = b;
}
void right_rotate(node *&a) {
node *b = a -> left;
a -> left = b -> right;
update(a);
b -> right = a;
update(b);
a = b;
}
void update(node *p) {
if (!p) return;
p -> size = 1;
if (p -> left) p -> size += p -> left -> size;
if (p -> right) p -> size += p -> right -> size;
}
int find(node *p, int k) {
if (!p) return 0;
if (p -> value <= k) {
int ret = 1;
if (p -> left) ret += p -> left -> size;
return ret + find(p -> right, k);
} else {
return find(p -> left, k);
}
}
void insert(node *&p, int value) {
if (!p) {
p = new node(value);
} else if (value <= p -> value) {
insert(p -> left, value);
if (p -> left -> fix < p -> fix) {
right_rotate(p);
}
} else {
insert(p -> right, value);
if (p -> right -> fix < p -> fix) {
left_rotate(p);
}
}
update(p);
}
void erase(node *&p, int value) {
if (p -> value == value) {
if (!p -> right || !p -> left) {
node *t = p;
if (!p -> right) {
p = p -> left;
} else {
p = p -> right;
}
delete t;
} else {
if (p -> left -> fix < p -> right -> fix) {
right_rotate(p);
erase(p -> right, value);
} else {
left_rotate(p);
erase(p -> left, value);
}
}
} else if (value < p -> value) {
erase(p -> left, value);
} else {
erase(p -> right, value);
}
update(p);
}
}T[N];
void dfs(int x, int d) {
size[x] = 1;
deep[x] = d;
for(auto t: g[x]) {
dfs(t, d + 1);
size[x] += size[t];
if (size[t] > size[son[x]]) son[x] = t;
}
}
void add(int x) {
//cout << "add " << x << endl;
T[v[x]].insert(T[v[x]].root, deep[x]);
//cout << "add finished" << endl;
for(auto t: g[x]) add(t);
}
void clear(int x, int w) {
//cout << "clear " << x << ' ' << w << endl;
if (w) T[v[x]].erase(T[v[x]].root, deep[x]);
for(auto t: g[x]) clear(t, 1);
}
void calc(int x, int w) {
int h = 2 * v[w] - v[x];
//printf("calc %d %d\n", x, w);
//cout << "calc " << x << ' ' << w << endl;
if (h >= 0 && h < N) {
//cout << "calc " << x << ' ' << w << ' ' << h << endl;
ans += T[h].find(T[h].root, k + 2 * deep[w] - deep[x]);
//cout << ans << endl;
}
for(auto t: g[x]) calc(t, w);
}
void dsu(int x, int keep) {
//cout << "son " << x << ' ' << son[x] << ' ' << keep << endl;
//cout << "deep " << x << ' ' << deep[x] << endl;
for(int t: g[x]) {
if (t == son[x]) continue;
dsu(t, 0);
}
if (son[x]) dsu(son[x], 1);
for(int t: g[x]) {
if (t == son[x]) continue;
//cout << x << ' ' << t << endl;
//printf("calc %d %d", t, x);
calc(t, x);
add(t);
}
if (!keep) {
clear(x, 0);
} else {
//cout << "insert " << x << ' ' << deep[x] << endl;
T[v[x]].insert(T[v[x]].root, deep[x]);
//cout << "insert finished" << endl;
}
}
int main() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i ++) scanf("%d", &v[i]);
for (int i = 2; i <= n; i ++) {
scanf("%d", &p);
g[p].push_back(i);
}
//if (!T[0].root) cout << "yes" << endl;
dfs(1, 0);
dsu(1, 1);
ans *= 2;
printf("%lld\n", ans);
return 0;
}