IOI 系列的题目就不提示啦……
博主一开始想找一个时间复杂度固定的算法来求解这道题(莫队神马的本题根本不能用,这玩意原来是个交互题,所以强制在线)。后来翻年鉴发现此题需要利用多个算法来平衡时间复杂度。我们一一来说。
首先,我们用 DFS 序 ,来把问题转化成一个区间问题。此时两种颜色分别为 x,y , 每种颜色在树中分别有 a,b 个点
算法一:
事先我们预处理颜色为
y
的点,将他们按照
时间复杂度:
算法二:
事先我们差分颜色为
x
的点的子树在
时间复杂度:
算法三:
我们将颜色为
x
的区间和颜色为
时间复杂度为: O(a+b)
最后值得一提的是,我们需要记忆化得到的答案,这样就保证了如果经常查询大数据不会超时。(因为大数据的种类数是很有限的)
而在每种询问选择算法的时候可以直接比较此时的渐进时间复杂度来选择,至于总时间复杂度按照年鉴的分析是 O(nnlog(n)−−−−−−−√) ,有兴趣的小伙伴可以参考年鉴。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+1e2;
vector<int> g[maxn] , b[maxn] , s[maxn]; vector<pair<int , int > > t[maxn];
int n , r , q , dfsCnt , id[maxn] , reid[maxn] , c[maxn] , Size[maxn];
void dfs(int x)
{
Size[x] = 1;
reid[id[x] = ++dfsCnt] = x;
b[c[x]].push_back(id[x]);
for(int i=0;i<g[x].size();i++) dfs(g[x][i]) , Size[x] += Size[g[x][i]];
}
int x , y;
map<pair<int , int> , long long> dic;
void solve1()
{
long long res = 0;
for(int i=0,l,r;i<b[x].size();i++)
{
l = b[x][i]; r = l + Size[reid[l]] - 1;
l = lower_bound(b[y].begin() , b[y].end() , l) - b[y].begin();
r = upper_bound(b[y].begin() , b[y].end() , r) - b[y].begin() - 1;
res += r-l+1;
}
dic[make_pair(x , y)] = res;
printf("%lld\n" , res);
}
void solve2()
{
long long res = 0;
for(int i=0,now;i<b[y].size();i++)
{
now = b[y][i];
now = lower_bound(t[x].begin() , t[x].end() , make_pair(now , 10)) - t[x].begin() - 1;
res += now < 0 ? 0 : s[x][now];
}
dic[make_pair(x , y)] = res;
printf("%lld\n" , res);
}
pair<int , int> li[maxn]; int top;
void solve3()
{
long long res = 0;
int cnt1 = 0 , cnt2 = top = 0;
while(cnt1 != t[x].size() || cnt2 != b[y].size())
{
if(cnt1 == t[x].size()) li[top++] = make_pair(b[y][cnt2++] , 2);
else if(cnt2 == b[y].size()) li[top++] = t[x][cnt1++];
else if(t[x][cnt1].first <= b[y][cnt2]) li[top++] = t[x][cnt1++];
else li[top++] = make_pair(b[y][cnt2++] , 2);
}
for(int i=0,s=0;i<top;i++)
if(li[i].second == 2) res += s;
else s += li[i].second;
dic[make_pair(x , y)] = res;
printf("%lld\n" , res);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in","r",stdin);
#endif
cin>>n>>r>>q;
for(int i=1;i<=n;i++)
{
if(i-1)
{
int fa;
scanf("%d" , &fa);
g[fa].push_back(i);
}
scanf("%d" , c+i);
}
dfs(1);
for(int i=1;i<=r;i++)
{
for(int j=0;j<b[i].size();j++) t[i].push_back(make_pair(b[i][j] , 1)) , t[i].push_back(make_pair(b[i][j] + Size[reid[b[i][j]]] , -1));
sort(t[i].begin() , t[i].end()); sort(b[i].begin() , b[i].end());
for(int j=0,sum=0;j<t[i].size();j++) s[i].push_back(sum += t[i][j].second);
}
while(q--)
{
scanf("%d%d" , &x , &y);
if(dic.count(make_pair(x , y))) printf("%lld\n" , dic[make_pair(x , y)]);
else
{
int c1 = b[x].size() * log2(b[y].size());
int c2 = b[y].size() * log2(t[x].size());
int c3 = t[x].size() + b[y].size();
int mn = min(c1 , min(c2 , c3));
if(mn == c1) solve1();
else if(mn == c2) solve2();
else solve3();
}
}
return 0;
}