http://www.elijahqi.win/archives/376
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
u v : ask for how many different integers that represent the weight of nodes there are on the path from u to v.
Input
In the first line there are two integers N and M. (N <= 40000, M <= 100000)
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains two integers u v, which means an operation asking for how many different integers that represent the weight of nodes there are on the path from u to v.
Output
For each operation, print its result.
Example
Input:
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
Output:
4
4
找错加调试加学习整整7个小时啊 后来无意把N 多加一个0就过了,我也不知道这个数据是要闹哪样啊。
树上莫队仍然分块,只不过根据dfs序分块
首先将树分块,然后以所属块的编号为第一关键字,以dfs序为第二关键字对询问排序,下面只需要考虑如何由(u,v)链->(u’,v’)链了
令S(u,v)表示u~v的点的集合
S(u,v)=S(root,u) xor S(root,v) xor lca(u,v)
令 T(u,v)=S(root,u) xor S(root,v)
考虑T(u,v)->T(u,v’)
T(u,v) xor T(u,v’)=S(root,v) xor S(root,v’)
T(u,v’)=T(u,v) xor S(root,v) xor S(root,v’)
T(u,v’)=T(u,v) xor T(v,v’)
对lca单独考虑即可
之前没有学习过倍增的lca 洛谷的Lca模板还是用lct+o2优化卡过
倍增lca:首先两个不同深度的节点要先通过倍增给他们上升到同一深度,然后再采取倍增的方法,注意一定从大到小如2^3–>2^2
如何倍增,倍增使用了一个fa[哪个节点][2的i次方]
深搜时:要特殊判断,避免搜索到自己的父节点 dfs返回当前子树搜索到多少个节点,如果超过sqrt(n)那么,给这些节点分在一块内
开一个栈存储这些还没有编号的节点,最后深搜完毕,把剩下没有编号的统一再编号
处理询问的时候,我们按照已经排好序的询问来做,先处理第一个询问
剩下的询问都是在第一个询问的基础上扩展,可以由前面的公式推导得到,针对存在性取反的时候,类似以前做过的莫队题目,如果这个节点出现过,就去看这个节点的权值
最后 提醒本题需要离散化,存边的结构体开两倍
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<map>
#define N 550000
#define M 110000
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9') {ch=getchar();}
while (ch<='9'&&ch>='0') {x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void swap(int &x,int &y){
int t=x;x=y;y=t;
}
struct node{
int y,next;
}data[N<<1];
struct node1{
int l,r,id;
}q[M];
int n,m,n1,c[N],c1[N],dfn[N],block_num,top,f[N],ans[N],num,Log[N],low[N],fa[N][20],h[N],bl[N];
map<int,int> mm;
//int mm[N];
int stack[N],ans1;bool visit[N];
int dfs(int x){
dfn[x]=++num;int size=0;
for (int i=1;i<=Log[low[x]];++i) fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=h[x];i;i=data[i].next){
int y=data[i].y;
if (fa[x][0]==y) continue;
fa[y][0]=x;low[y]=low[x]+1;
size+=dfs(y);
if (size>=n1){
block_num++;
for (int i=1;i<=size;++i){
bl[stack[top--]]=block_num;
}
size=0;
}
}
stack[++top]=x;
return size+1;
}
inline bool cmp(node1 a,node1 b){
return bl[a.l]==bl[b.l]?dfn[a.r]<dfn[b.r]:bl[a.l]<bl[b.l];
}
inline int lca(int x,int y){
if (low[x]<low[y]) swap(x,y);
int dis=low[x]-low[y];
for (int i=0;i<=Log[dis];++i) if (dis&(1<<i)) x=fa[x][i];
if (x==y) return x;
for (int i=Log[n];i>=0;--i){
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
inline void reserve(int x){ //针对存在性取反 并且统计答案
if (visit[x]){
f[c[x]]--;if (!f[c[x]]) ans1--;visit[x]=false;
}else{
f[c[x]]++;if (f[c[x]]==1) ans1++;visit[x]=true;
}
//visit[x]^=1;
}
inline void solve(int x,int y){
while (x!=y) if (low[x]<low[y]) reserve(y),y=fa[y][0];else reserve(x),x=fa[x][0];
}
int main(){
freopen("10707.in","r",stdin);
//freopen("10707.out","w",stdout);
n=read();m=read();n1=sqrt(n);
for (int i=1;i<=n;++i) c[i]=read(),c1[i]=c[i];
sort(c1+1,c1+n+1);
// for (int i=1;i<=n;++i) printf("%d ",c[i]);
int tmp=std::unique(c1+1,c1+n+1)-c1-1;
for (int i=1;i<=tmp;++i) mm[c1[i]]=i;
for (int i=1;i<=n;++i) c[i]=mm[c[i]];
memset(h,0,sizeof(h));Log[0]=-1;
int tmp1,tmp2,num=0;for (int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;
for (int i=1;i<n;++i) {
tmp1=read();tmp2=read();
data[++num].y=tmp2;data[num].next=h[tmp1];h[tmp1]=num;
data[++num].y=tmp1;data[num].next=h[tmp2];h[tmp2]=num;
}
num=0;dfs(1);block_num++;
while (top) bl[stack[top--]]=block_num;
//for (int i=1;i<=n;++i) printf("%d ",bl[i]);
for (int i=1;i<=m;++i){q[i].l=read();q[i].r=read();q[i].id=i;if (bl[q[i].l]>bl[q[i].r]) swap(q[i].l,q[i].r);}
sort(q+1,q+m+1,cmp);
//for (int i=1;i<=m;++i) printf("%d %d\n",q[i].l,q[i].r);
/* for (int i=1;i<=n;++i){
for (int j=1;j<=4;++j) printf("%d ",fa[i][j]);
printf("\n");
}*/
tmp=lca(q[1].l,q[1].r);
// printf("%d ",tmp);
solve(q[1].l,q[1].r);
ans[q[1].id]=ans1+!(f[c[tmp]]);
for (int i=2;i<=m;++i){
solve(q[i-1].l,q[i].l);solve(q[i-1].r,q[i].r);
tmp=lca(q[i].l,q[i].r);
ans[q[i].id]=ans1+!(f[c[tmp]]);
}
for (int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}