题目大意
给出多个询问u , v , 求出u-v路径上点权值不同的个数
开始做的是COT1,用主席树写过了,理解起来不难
很高兴的跑去做第二道,完全跟普通数组区间求k个不同有很大区别,完全没思路
膜拜http://www.cnblogs.com/oyking/p/4265823.html
这里利用莫队思想来做,在树上分块,尽可能让相连部分作为一个联通块,那么就在dfs过程中加个手写栈,如果回溯上来的时候保存的值的个数超过每块中应有的
个数那么就将他们分到同一个id块
排序也是跟普通莫队上一样,按分块编号排序,这里有两个端点,那么就先将编号小的摆在前面在排序
离线排序做好了,剩下就是转移的问题,从当前一条路径转移到下一条,diff保存了之前记录的路径上不同点的个数
之前是u,v , 现在去curu , curv , 那么先找到u-curu路径上的点,添加进来,v-curv路径上的点添加进来
画个图可以看出这是原来的路径基础上不需要的点会被再多访问一次,那些需要的点要么本身在原基础上不再访问,要么又多访问两次保持不变,像异或一样
所以记录当前点有没有被访问奇数次就行了,多访问就翻转
最后会有多余的点没有被访问到,就是LCA(u,v) LCA(curu , curv) , 这个要在访问一次即可
这里因为是分块的,所以跑的是曼哈顿距离的询问区间,不会超时
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define N 40010 5 int first[N] , k; 6 int block[N] , sz;//对应分到的块的编号和每块应含的大小 7 int _stack[N] , top , cursz , ID;//手写栈 8 int n , m; 9 int hash[N] , a[N] , b[N]; 10 11 struct Edge{ 12 int x , y , next; 13 Edge(){} 14 Edge(int x , int y , int next):x(x),y(y),next(next){} 15 }e[N<<1]; 16 17 void add_edge(int x , int y) 18 { 19 e[k] = Edge(x , y , first[x]); 20 first[x] = k++; 21 } 22 23 int dp[N<<1][30] , id[N<<1] , dep[N<<1] , No[N] , fa[N] , dfs_clock; 24 int depth[N]; 25 26 void add_block(int &cursz , int ID) 27 { 28 while(cursz){ 29 block[_stack[--top]] = ID; 30 // cout<<"IN: "<<ID<<" "<<_stack[top]<<" "<<top<<" "<<sz<<endl; 31 cursz--; 32 } 33 } 34 35 void dfs(int u , int f , int d) 36 { 37 id[++dfs_clock] = u , No[u] = dfs_clock , dep[dfs_clock] = d; 38 fa[u] = f , depth[u] = d; 39 for(int i=first[u] ; ~i ; i=e[i].next){ 40 int v = e[i].y; 41 if(v == f) continue; 42 dfs(v , u , d+1); 43 id[++dfs_clock] = u , dep[dfs_clock] = d; 44 } 45 //树上分块重要部分 46 cursz++; 47 _stack[top++] = u; 48 if(cursz>=sz) add_block(cursz , ++ID); 49 } 50 51 void ST(int n) 52 { 53 for(int i=1 ; i<=n ; i++) dp[i][0] = i; 54 for(int j=1 ; (1<<j)<=n ; j++){ 55 for(int i=1 ; i+(1<<j)-1<=n ; i++){ 56 int a = dp[i][j-1] , b=dp[i+(1<<(j-1))][j-1]; 57 dp[i][j] = dep[a]<dep[b]?a:b; 58 } 59 } 60 } 61 62 int RMQ(int l , int r) 63 { 64 int k=0; 65 while((1<<(k+1))<=r-l+1) k++; 66 int a = dp[l][k] , b = dp[r-(1<<k)+1][k]; 67 return dep[a]<dep[b]?a:b; 68 } 69 70 int LCA(int u , int v) 71 { 72 int x=No[u] , y=No[v]; 73 if(x>y) swap(x , y); 74 return id[RMQ(x,y)]; 75 } 76 77 void get_hash(int n){ 78 for(int i=1 ; i<=n ; i++) 79 hash[i] = lower_bound(b+1 , b+n+1 , a[i])-b; 80 } 81 82 struct Query{ 83 int u , v , id; 84 void reset(){ 85 if(block[u]>block[v]) swap(u , v); 86 } 87 bool operator<(const Query &m) const{ 88 return block[u]<block[m.u]||(block[u]==block[m.u] && block[v]<block[m.v]); 89 } 90 void in(int i){scanf("%d%d" , &u , &v);id=i;} 91 }qu[100010]; 92 93 int ans[100010] , vis[N] , cnt[N] , diff; 94 95 void xorNode(int x) 96 { 97 // cout<<"xor: "<<x<<endl; 98 if(vis[x]) vis[x]=false , diff -= (--cnt[hash[x]]==0); 99 else vis[x] = true , diff += (++cnt[hash[x]]==1); 100 } 101 102 void xorPath(int x , int y) 103 { 104 // cout<<"path: "<<x<<" "<<y<<endl; 105 if(depth[x]<depth[y]) swap(x , y); 106 while(depth[x]>depth[y]){ 107 xorNode(x); 108 x = fa[x]; 109 } 110 while(x!=y){ 111 xorNode(x) , xorNode(y); 112 x = fa[x] , y=fa[y]; 113 } 114 } 115 116 void debug() 117 { 118 // for(int i=1 ; i<=n ; i++) cout<<"i: "<<block[i]<<" "<<hash[i]<<" fa: "<<fa[i]<<" "<<depth[i]<<endl; 119 cout<<"test: "<<LCA(7,8)<<" "<<LCA(7 , 1)<<endl; 120 } 121 void solve() 122 { 123 sz = (int)sqrt(n+0.5); 124 dfs(1 , 0 , 1); 125 ST(n*2-1); 126 add_block(cursz , ++ID); 127 // debug(); 128 for(int i=0 ; i<m ; i++){ 129 qu[i].in(i); 130 qu[i].reset(); 131 } 132 sort(qu , qu+m); 133 memset(vis , 0 , sizeof(vis)); 134 diff = 0; 135 int curu=1 , curv=1; 136 xorNode(1); 137 for(int i=0 ; i<m ; i++){ 138 xorPath(curu , qu[i].u); 139 xorPath(curv , qu[i].v); 140 xorNode(LCA(curu , curv)); 141 xorNode(LCA(qu[i].u , qu[i].v)); 142 curu = qu[i].u , curv = qu[i].v; 143 ans[qu[i].id] = diff; 144 // cout<<"----华丽的分割线---"<<endl; 145 } 146 for(int i=0 ; i<m ; i++) printf("%d\n" , ans[i]); 147 } 148 149 int main() 150 { 151 // freopen("in.txt" , "r" , stdin); 152 scanf("%d%d" , &n , &m); 153 for(int i=1 ; i<=n ; i++) scanf("%d" , &a[i]) , b[i]=a[i]; 154 sort(b+1 , b+n+1); 155 get_hash(n); 156 int x , y; 157 memset(first , -1 , sizeof(first)); 158 top = cursz = ID = k = 0; 159 for(int i=1 ; i<n ; i++){ 160 scanf("%d%d" , &x , &y); 161 add_edge(x , y); 162 add_edge(y , x); 163 } 164 solve(); 165 return 0; 166 }