You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v : ask how many different integer weights appear on the nodes of the path from u to v.
Input
The first line contains two integers N and M (N <= 40000, M <= 100000).
The second line contains N integers; the i-th integer denotes the weight of the i-th node.
Each of the next N-1 lines contains two integers u v, describing an edge (u, v).
Each of the next M lines contains two integers u v, an operation asking how many different integer weights appear on the nodes of the path from u to v.
Output
For each operation, print its result.
题意:给你一棵树,树上每个节点有一个权值,然后给你若干个询问,每次询问让你找出一条链上有多少个不同的权值。
解题思路:如果这个问题是在一个序列上做,那么我们可以用莫队算法;但现在是在树上,所以我们先对树进行分块,然后同样可以用莫队。不过树上莫队的转移没有序列上莫队的转移那么容易,这里需要用到 LCA,然后用集合的对称差来转移,还是很巧妙的。
#include <bits/stdc++.h>
using namespace std;
// Global state shared by every routine of the solution. Per-node arrays are
// indexed by node id 1..N, per-query arrays by query number 1..M.
const int maxn = 40000 + 10; // capacity of per-node arrays (statement bound N <= 40000 plus slack)
const int maxm = 100000 + 10; // capacity of per-query arrays (statement bound M <= 100000 plus slack)
int N, M; // number of nodes and number of queries, read in main
int block;// block size used when partitioning the tree
int nowblock;// index of the block currently being filled
int sta[maxn];// stack of nodes still waiting to be assigned to a block
int top;// index of the stack top (0 means empty)
vector<int> g[maxn];// adjacency lists of the tree
int pos[maxn];// index of the block each node belongs to
int depth[maxn];// depth of each node (the root passed to LCA has the depth given there)
int value1[maxn];// original node weights
int value2[maxn];// node weights after coordinate compression (1-based ranks)
int H[maxn];// scratch array used to compress the weights
int num[maxn];// how many currently counted nodes carry each compressed weight
int appear[maxn];// 0/1 flag: whether the node is currently counted
int father[maxn]; // parent of each node as discovered by LCA(); init() sets father[1] = -1
// DFS visited flags, reused by LCA() and dfs_block().
// NOTE(review): with `using namespace std;` this name can clash with std::visit
// when <variant> is included (e.g. <bits/stdc++.h> built as C++17) - confirm the
// target standard or rename.
bool visit[maxn];
int Index[maxn<<2]; // Euler tour: node recorded at each tour position (positions start at 1)
int dp[maxn<<2][25]; // sparse table: dp[i][j] = shallowest node among tour positions i .. i+2^j-1
int First[maxn]; // first tour position at which each node appears
int Log[maxn<<2]; // Log[i] = floor(log2(i)) for i >= 1, Log[0] = -1
int res; // next free tour position; init() resets it to 1
// One path query prepared for offline processing.
//   l    - primary sort key (main stores the block index pos[u] here)
//   r    - secondary sort key (main stores the tour position First[v] here)
//   u, v - the two endpoints of the path
//   id   - number of the query in input order, used to place the answer
struct query{
    int l, r, u, v, id;

    // Orders queries by l, breaking ties by r.
    bool operator <(const query &other) const{
        if(l != other.l) return l < other.l;
        return r < other.r;
    }
}Query[maxm];
void init()
{
res = 1;
top = 0;
father[1] = -1;
memset(visit, false, sizeof(visit));
memset(appear, 0, sizeof(appear));
memset(num, 0, sizeof(num));
block = (int)sqrt(N);
nowblock = 0;
for(int i = 1; i <= N; i++) g[i].clear();
}
// Builds the floor(log2) table and the sparse table over the Euler tour.
// Expects Index[1 .. res-1] and depth[] to be filled already (done by LCA()).
// Afterwards dp[i][j] is the shallowest node among Index[i .. i + 2^j - 1].
void initRmq()
{
    const int lastPos = res - 1;   // last tour position in use

    Log[0] = -1;
    for(int i = 1; i <= lastPos; ++i)
    {
        // floor(log2(i)) increases by one exactly when i is a power of two
        const bool isPowerOfTwo = (i & (i - 1)) == 0;
        Log[i] = Log[i - 1] + (isPowerOfTwo ? 1 : 0);
        dp[i][0] = Index[i];       // a window of length 1 is the node itself
    }

    for(int j = 1; j < 20; ++j)
    {
        const int half = 1 << (j - 1);
        const int lastStart = res - (1 << j);   // largest i with i + 2^j - 1 < res
        for(int i = 1; i <= lastStart; ++i)
        {
            const int leftBest = dp[i][j - 1];
            const int rightBest = dp[i + half][j - 1];
            // on equal depth the right half wins, as only a strict "<" picks the left
            dp[i][j] = (depth[leftBest] < depth[rightBest]) ? leftBest : rightBest;
        }
    }
}
// Returns the shallowest node among tour positions l..r (inclusive, l <= r),
// combining two overlapping power-of-two windows of the sparse table.
int Rmq(int l,int r)
{
    const int k = Log[r - l + 1];
    const int fromLeft = dp[l][k];
    const int fromRight = dp[r - (1 << k) + 1][k];
    if(depth[fromLeft] < depth[fromRight]) return fromLeft;
    return fromRight;
}
void LCA(int root,int d)//获得
{
First[root] = res;
depth[root] = d;
Index[res++] = root;
visit[root] = true;
for(int i = 0; i < g[root].size(); i++)
{
int v = g[root][i];
if(!visit[v])
{
father[v] = root;
LCA(v,d + 1);
Index[res++] = root;
}
}
}
// Lowest common ancestor of u and v: the shallowest node between their first
// occurrences in the Euler tour.
int getLca(int u, int v)
{
    const int a = First[u];
    const int b = First[v];
    return Rmq(min(a, b), max(a, b));
}
int dfs_block(int u)
{
int sum = 0;
visit[u] = true;
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(!visit[v])
{
sum += dfs_block(v);
if(sum >= block)
{
while(sum--) pos[sta[top--]] = nowblock;
sum = 0;
nowblock++;
}
}
}
sta[++top] = u;
return sum + 1;
}
// Coordinate-compresses the node weights: value2[i] becomes the 1-based rank
// of value1[i] among the distinct weights of nodes 1..N.
void initHash()
{
    // copy the weights into the scratch array H[0 .. count-1]
    int count = 0;
    while(count < N)
    {
        H[count] = value1[count + 1];
        ++count;
    }

    sort(H, H + count);
    int *distinctEnd = unique(H, H + count);

    for(int node = 1; node <= N; ++node)
    {
        int *where = lower_bound(H, distinctEnd, value1[node]);
        value2[node] = where - H + 1;
    }
}
// L, R: endpoints of the path currently represented by appear[]/num[];
// ans: number of distinct compressed weights among the counted nodes.
int L, R, ans;

// Toggles node v in the counted set (adds it if absent, removes it if
// present), keeps num[] and ans in sync, then moves v up to its parent.
// v is taken by reference so callers can walk towards an ancestor.
void work(int &v)
{
    int &occurrences = num[value2[v]];
    if(!appear[v])
    {
        ++occurrences;
        if(occurrences == 1) ++ans;    // first node with this weight
    }
    else
    {
        --occurrences;
        if(occurrences == 0) --ans;    // last node with this weight left
    }
    appear[v] ^= 1;
    v = father[v];
}
int result[maxm];
int main()
{
//freopen("C:\\Users\\creator\\Desktop\\in1.txt","r",stdin) ;
scanf("%d%d", &N, &M);
for(int i = 1; i <= N; i++)
{
scanf("%d", &value1[i]);
}
init();
for(int i = 1; i < N; i++)
{
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
LCA(1, 0);
initRmq();
initHash();
memset(visit, false, sizeof(visit));
dfs_block(1);
while(top) pos[sta[top--]] = nowblock;
for(int i = 1; i <= M; i++)
{
int u, v;
scanf("%d%d", &u, &v);
if(pos[u] > pos[v]) swap(u, v);
Query[i].l = pos[u];
Query[i].r = First[v];
Query[i].u = u;
Query[i].v = v;
Query[i].id = i;
}
sort(Query + 1, Query + M + 1);
L = 1;
R = 1;
ans = 0;
memset(visit, false, sizeof(visit));
for(int i = 1; i <= M; i++)
{
int u = Query[i].u;
int v = Query[i].v;
int id = Query[i].id;
int lca = getLca(u, v);
int lca1 = getLca(L, u);
int lca2 = getLca(R, v);
while(L != lca1) work(L);
while(u != lca1) work(u);
while(R != lca2) work(R);
while(v != lca2) work(v);
if(num[value2[lca]] == 0) result[id] = ans + 1;
else result[id] = ans;
L = Query[i].u;
R = Query[i].v;
}
for(int i = 1; i <= M; i++)
{
printf("%d\n", result[i]);
}
return 0;
}