题意:n个节点形成的一棵树。每个节点有一个值。m次查询,求出(u,v)路径上出现了多少个不同的数。
树上的莫队算法,同样将树分成siz=sqrt(n)块,然后离线操作。先对树dfs一遍,每当子树节点个数num>=siz,就将这num个分成一块。读取所有的查询按左端点所在块排序。
重点在于怎么进行区间转移,对路径的lca特殊处理,参考博客http://blog.csdn.net/kuribohg/article/details/41458639
用倍增法求lca单次要用logn复杂度,要跑3200ms。有个地方可以优化,就是知道了所有的查询,也就是事先知道了转移路径,可以用离线的方法求O(n)求出所有需要用到的lca,这个写起来比较麻烦,不过可以优化到1800ns。代码写的比较挫。。。。
logn求lca:3200+ms
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
const int maxn=4e4+10;
const int maxm=1e5+10;
int n,m, siz;
vector<int> g[maxn];
int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn];
int fa[maxn][20], dep[maxn];
struct Query
{
int l, r, id;
int st,ed;
bool operator <(const Query& a) const
{
return st!=a.st? st<a.st: ed<a.ed; //先按左端点所在块先后排序,其次考虑又右端点所在块
}
};
Query q[maxm];
int tag, bel[maxn];
int st[maxn], top;
int dfs(int u, int par, int d, int &cnt)
{
dep[u]=d; fa[u][0]=par;
int num=0;
for(int i=0; i<g[u].size(); i++){
int v=g[u][i];
if(v!=par){
num+=dfs(v, u, d+1, cnt);
if(num>=siz){ //子树大小>=sqrt(n),分成一块
for(int i=0; i<num; i++)
bel[st[--top]]=tag;
tag++;
num=0;
}
}
}
st[top++]=u;//记录子树遍历的点
return num+1;
}
void init()
{
for(int i=0; i<=n; i++) g[i].clear();
memset(tot, 0, sizeof(tot));
memset(in, 0, sizeof(in));
siz=sqrt(n);
for(int i=1;i<=n; i++) scanf("%d",&a[i]), b[i]=a[i];
sort(b+1, b+n+1);
for(int i=1; i<=n; i++)
a[i]=lower_bound(b+1, b+n+1, a[i])-b;
for(int i=0; i<n-1; i++){
int u,v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
int cnt=0; tag=top=0;
int num=dfs(1, -1, 0, cnt);
for(int i=0; i<num; i++)
bel[st[--top]]=tag; //最后剩下的数也分成一块
for(int i=1; i<20; i++){
for(int u=1; u<=n; u++)
if(fa[u][i-1]==-1)
fa[u][i]=-1;
else fa[u][i]=fa[fa[u][i-1]][i-1];
}
for(int i=0; i<m; i++){
scanf("%d%d", &q[i].l, &q[i].r);
if(bel[q[i].l]>bel[q[i].r])
swap(q[i].l, q[i].r);
q[i].id=i;
q[i].st=bel[q[i].l];
q[i].ed=bel[q[i].r];
}
sort(q, q+m);
}
int lca(int u, int v)
{
if(dep[u]>dep[v]) swap(u, v);
for(int i=0; i<20; i++)
if((dep[v]-dep[u])>>i&1)
v=fa[v][i];
if(u==v) return u;
for(int i=19; i>=0; i--){
if(fa[u][i]!=fa[v][i]){
u=fa[u][i];
v=fa[v][i];
}
}
return fa[u][0];
}
void solve()
{
int res=0;
int cu=1, cv=1;
for(int i=0; i<m; i++){
int nu=q[i].l, nv=q[i].r;
int par=lca(cu, nu);
while(cu!=par){
if(in[cu]){
if(--tot[a[cu]]==0)
res--;
}
else if(++tot[a[cu]]==1)
res++;
in[cu]^=1;
cu=fa[cu][0];
}
cu=nu;
while(cu!=par){
if(in[cu]){
if(--tot[a[cu]]==0)
res--;
}
else if(++tot[a[cu]]==1)
res++;
in[cu]^=1;
cu=fa[cu][0];
}
cu=nu;
par=lca(cv, nv);
while(cv!=par){
if(in[cv]){
if(--tot[a[cv]]==0)
res--;
}
else if(++tot[a[cv]]==1)
res++;
in[cv]^=1;
cv=fa[cv][0];
}
cv=nv;
while(cv!=par){
if(in[cv]){
if(--tot[a[cv]]==0)
res--;
}
else if(++tot[a[cv]]==1)
res++;
in[cv]^=1;
cv=fa[cv][0];
}
cv=nv;
par=lca(cu, cv);
ans[q[i].id]=res+(!tot[a[par]]);
}
}
int main()
{
while(scanf("%d%d", &n, &m)==2){
init();
solve();
for(int i=0; i<m; i++)
printf("%d\n", ans[i]);
}
return 0;
}
离线查询lca:1800+ms
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
typedef pair<int,int> P;
#define fir first
#define sec second
const int maxn=4e4+10;
const int maxm=1e5+10;
int n,m, siz;
vector<int> g[maxn];
int first[maxn],ltot=0, nxt[6*maxm];
P lq[6*maxm];//所有需要查询的lca,lq[i].first保存v,second保存查询的id
int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn], fa1[maxn];
int fa[maxn], lca[3*maxm], col[maxn];
int bel[maxn],st[maxn],top=0;
struct Query
{
int l, r, id;
int st,ed;
bool operator <(const Query& a) const
{
return st!=a.st? st<a.st: ed<a.ed;
}
};
Query q[maxm];
int tag;
int dfs(int u, int par, int &cnt)//分块
{
fa1[u]=par;
int num=0;
for(int i=0; i<g[u].size(); i++){
int v=g[u][i];
if(v!=par)
num+=dfs(v, u, cnt);
if(num>=siz){
for(int i=0; i<num; i++)
bel[st[--top]]=tag;
tag++;
num=0;
}
}
st[top++]=u;
return num+1;
}
int find(int u)
{
return fa[u]==u?u:(fa[u]=find(fa[u]));
}
int unite(int x, int y)
{
x=fa[x];
y=fa[y];
fa[y]=x;
}
void dfs2(int u, int par)//离线查询所有lca
{
col[u]=1;
for(int i=first[u]; i!=-1; i=nxt[i]){
int v=lq[i].fir, id=lq[i].sec;
if(!col[v]) continue;
else if(col[v]==1){
lca[id]=v;
}
else{
lca[id]=find(v);
}
}
for(int i=0; i<g[u].size(); i++){
int v=g[u][i];
if(v!=par)
dfs2(v, u);
}
col[u]=2;
unite(par, u);
}
void add(int u, int v, int id)//查询m<=1e5,数比较多所以用前向星实现优化
{
lq[ltot]=P(v,id);
nxt[ltot]=first[u];
first[u]=ltot++;
}
void init()
{
for(int i=0; i<=n; i++) g[i].clear();
memset(tot, 0, sizeof(tot));
memset(in, 0, sizeof(in));
siz=sqrt(n);
for(int i=1;i<=n; i++) scanf("%d", a+i), b[i]=a[i];
sort(b+1, b+n+1);
for(int i=1; i<=n; i++)
a[i]=lower_bound(b+1, b+n+1, a[i])-b;
for(int i=0; i<n-1; i++){
int u,v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
int cnt=0; top=0; tag=0;
int num=dfs(1, -1, cnt);
for(int i=0; i<num; i++)
bel[st[--top]]=tag;
for(int i=0; i<m; i++){
scanf("%d%d", &q[i].l, &q[i].r);
if(bel[q[i].l]>bel[q[i].r])
swap(q[i].l, q[i].r);
q[i].id=i;
q[i].st=bel[q[i].l];
q[i].ed=bel[q[i].r];
}
sort(q, q+m);
cnt=0; ltot=0;
memset(first, -1, sizeof(first));
add(1, q[0].l, cnt);
add(q[0].l, 1, cnt++);
add(1, q[0].r, cnt);
add(q[0].r, 1, cnt++);
add(q[0].r, q[0].l, cnt);
add(q[0].l, q[0].r, cnt++);
//add(q[0].r, q[0].l, cnt++);
for(int i=0; i<m-1; i++){
add(q[i].l, q[i+1].l, cnt);//第i个查询左端点向第i+1个左端点转移,所以需要它们之间的lca
add(q[i+1].l, q[i].l, cnt++);
add(q[i].r, q[i+1].r, cnt);//第i个查询右端点向第i+1个右端点转移
add(q[i+1].r, q[i].r, cnt++);
add(q[i+1].r, q[i+1].l, cnt);//左端点和右端点的lca
add(q[i+1].l, q[i+1].r,cnt++);
}
for(int i=0; i<=n; i++) fa[i]=i;
memset(col, 0, sizeof(col));
dfs2(1, 0);
}
void solve()
{
int res=0;
int cu=1, cv=1;
for(int i=0; i<m; i++){
int nu=q[i].l, nv=q[i].r;
//cout<<lca[i*3]<<' '<<lca[i*3+1]<<' '<<lca[i*3+2]<<endl;
int par=lca[i*3];
while(cu!=par){
if(in[cu]){
if(--tot[a[cu]]==0)
res--;
}
else if(++tot[a[cu]]==1)
res++;
in[cu]^=1;
cu=fa1[cu];
}
cu=nu;
while(cu!=par){
if(in[cu]){
if(--tot[a[cu]]==0)
res--;
}
else if(++tot[a[cu]]==1)
res++;
in[cu]^=1;
cu=fa1[cu];
}
cu=nu;
par=lca[i*3+1];
while(cv!=par){
if(in[cv]){
if(--tot[a[cv]]==0)
res--;
}
else if(++tot[a[cv]]==1)
res++;
in[cv]^=1;
cv=fa1[cv];
}
cv=nv;
while(cv!=par){
if(in[cv]){
if(--tot[a[cv]]==0)
res--;
}
else if(++tot[a[cv]]==1)
res++;
in[cv]^=1;
cv=fa1[cv];
}
cv=nv;
par=lca[i*3+2];
ans[q[i].id]=res+(!tot[a[par]]);
}
}
int main()
{
while(cin>>n>>m){
init();
solve();
for(int i=0; i<m; i++)
printf("%d\n", ans[i]);
}
return 0;
}