COT2 - Count on a tree II
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v : ask for how many different integers that represent the weight of nodes there are on the path from u to v.
Input
In the first line there are two integers N and M. (N <= 40000, M <= 100000)
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains two integers u v, which means an operation asking for how many different integers that represent the weight of nodes there are on the path from u to v.
Output
For each operation, print its result.
Example
Input:
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
Output:
4
4
题意:给一个树图,每个点的点权(比如颜色编号),m个询问,每个询问是一个区间[a,b],图中两点之间唯一路径上有多少个不同点权(即多少种颜色)。n<40000,m<100000。
参考http://blog.csdn.net/kuribohG/article/details/41458639
树上莫队:
(1)DFS一次,对树进行分块,分成sqrt(n)块,每个点属于一个块。并记录每个点的DFS序。
(2)将m个询问区间用所属块号作为第一关键字,DFS序作为第二关键字进行排序。
(3)转移都是差不多的,靠具体问题分析转移方式。
转移:
用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可。
AC代码:
#include<iostream>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn = 4e4 + 5;
const int BLOCKSIZE = 210;
int tid[maxn], dep[maxn], fa[maxn], pos[maxn], st[maxn];//tid:dfs序,dep:深度,fa:父亲节点,pos:块位置
int dfsclk, stn, blockclk;
struct query
{
int left, right, id;
query(int left,int right,int id):left(left),right(right),id(id){}
bool operator <(const query &x) const
{
return tid[right] < tid[x.right];
}
};
vector<query>v[BLOCKSIZE];
struct node
{
int en, next;
}edge[2*maxn];
int num1, head[maxn];
void init()
{
num1 = blockclk = dfsclk = stn = 0;
memset(head, -1, sizeof(head));
}
void add(int st, int en)
{
edge[num1].en = en;
edge[num1].next = head[st];
head[st] = num1++;
}
int dfs(int u, int p)
{
tid[u] = ++dfsclk;
dep[u] = dep[p] + 1;
fa[u] = p;
int sz = 0;
for (int i = head[u];i != -1;i = edge[i].next)
{
int v = edge[i].en;
if (v == p) continue;
sz += dfs(v, u);
if (sz >= BLOCKSIZE)
{
while (sz--)
pos[st[--stn]] = blockclk;
blockclk++;
}
}
st[stn++] = u;
return sz + 1;
}
int num[maxn], cnt;
int val[maxn], t[maxn];
int cross;
bool in[maxn];
void insert(int u)
{
if (!num[val[u]]) cnt++;
num[val[u]]++;
}
void del(int u)
{
num[val[u]]--;
if (!num[val[u]]) cnt--;
}
void inv(int u)
{
if (in[u])
{
del(u);
in[u] = false;
}
else
{
insert(u);
in[u] = true;
}
}
void move_up(int &u)
{
if (!cross)
{
if (in[u] && !in[fa[u]]) cross = u;
else if (!in[u] && in[fa[u]]) cross = fa[u];
}
inv(u);
u = fa[u];
}
void move_to(int u, int v)
{
if (u == v) return;
cross = 0;
if (in[v]) cross = v;
while (dep[u] > dep[v]) move_up(u);
while (dep[u] < dep[v]) move_up(v);
while (u != v)
{
move_up(u);
move_up(v);
}
inv(u);
inv(cross);
}
int ans[100005];
int main()
{
init();
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1;i <= n;i++)
{
scanf("%d", &val[i]);
t[i] = val[i];
}
sort(t + 1, t + n + 1);
int maxm = unique(t + 1, t + n + 1) - t - 1;
for (int i = 1;i <= n;i++)
val[i] = lower_bound(t + 1, t + 1 + maxm, val[i]) - t;
for (int i = 1;i < n;i++)
{
int st, en;
scanf("%d%d", &st, &en);
add(st, en), add(en, st);
}
dfs(1, 0);
while (stn--) pos[st[stn]] = blockclk;
for (int i = 0;i < m;i++)
{
int st, en;
scanf("%d%d", &st, &en);
if (tid[st] > tid[en]) swap(st, en);
v[pos[st]].push_back(query(st, en, i));
}
cnt = 0;
for (int i = 0;i < BLOCKSIZE;i++)
{
if (!v[i].size()) continue;
sort(v[i].begin(), v[i].end());
int l = v[i][0].left;
int r = l;
insert(l);
in[l] = true;
for (int j = 0;j < (int)v[i].size();j++)
{
move_to(l, v[i][j].left);
move_to(r, v[i][j].right);
l = v[i][j].left, r = v[i][j].right;
ans[v[i][j].id] = cnt;
}
move_to(l, r);
del(r);in[r] = false;
}
for (int i = 0;i < m;i++)
printf("%d\n", ans[i]);
//system("pause");
return 0;
}