题意:有两颗n个点的树,找出最大子集,满足如下条件:
- 该点集在树1上为一条连续的链
- 该点集在树2上,两两无任何祖先关系
思路:
对于条件2:
点u如果是点v的祖先,那么u的子树一定包含v,即一定包含v的子树,涉及子树的问题可以考虑dfs序。
dfs序性质:dfs序可以将每颗子树划分为若干区间[l, r],且划分的若干区间只会有两种情况:一个区间完全包含另一个区间,两个区间两两不相交。
问题就转化为,集合中点的范围,两两不相交,且在树1成一条链。范围不相交,就可以用线段树或者set判断。
对于条件1:
初始想法就是:dfs时每个点依次加入,维护区间和判断和目前点集有无相交部分。但是树1的链是会从中间断开的,不合法时就需要回溯删掉集合中的某些点,直到集合中点无相交。写主席树为的就是可以直接得到任意链的线段树。
ABCDEF 到G 不合法时,就从A开始逐渐删除,直到找到合法的链头。(这样暴力做可以被长链+菊花卡成 n 2 n^2 n2,所以可以将暴力改为二分找合法链头,复杂度就是稳定最坏n l o g 2 n log^2n log2n)
Tips:因为这里用主席树维护的是区间和,区间修改有lazy标记,所以不能像普通的pushdown一样向下传递,更新子树,因为之后版本在向下传递lazy标记的时候,把原来树的状态给改变了(折磨了我好久…)。
解决办法:标记永久化,查找区间的时候,保证更新的区间,一定是你要找的区间(递归的时候写三个的那种)。然后一路更新下来,所有包含[l, r]的区间都是新开的节点。模板题hdu4348。
然后还有一种比较优美的解法就是:树上滑窗,精髓在于答案是单调的,所以窗口大小也是只增不减,回退也是回退窗口,每个点最多只进出一次,配合上线段树就是nlogn的复杂度,但是常数有点大。
我的代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long int
const int MAXN = 3e5 + 10;
const int N = 3e5 +10;
const int MAXM = 1e6 + 10;
const int INF=0x3f3f3f3f;
const ll LINF=0x3f3f3f3f3f3f3f3f;
const int NINF=-INF-1;
const ll mod=1e9+7;
#define PI acos(-1.0)
int T,n,first[MAXN],last[MAXN],top;
int sz[MAXN];
vector<int> e1[MAXN],e2[MAXN];
void getdfn(int u,int fa)
{
++top;
first[u]=top;
sz[u]=1;
for(auto it:e2[u])
{
if(it==fa) continue;
getdfn(it,u);
sz[u]+=sz[it];
}
return ;
};
int idx,root[MAXN],tim;
struct node{
int lc,rc;
int sum,lz;
}tr[40*N];
int build(int l,int r)
{
int p=++idx;
tr[p].sum=tr[p].lz=0;
if(l==r)
{
return p;
}
int mid=l+r>>1;
tr[p].lc=build(l,mid),tr[p].rc=build(mid+1,r);
tr[p].sum=tr[tr[p].lc].sum+tr[tr[p].rc].sum;
return p;
}
int update(int p,int l,int r,int x,int y,int val)
{ int q=++idx;
tr[q]=tr[p];
tr[q].sum=tr[p].sum+val*(y-x+1);
if(x==l&&y==r)
{
tr[q].lz+=val;
return q;
}
int mid=l+r>>1;
if(x>mid) tr[q].rc=update(tr[p].rc,mid+1,r,x,y,val);
else if(y<=mid) tr[q].lc=update(tr[p].lc,l,mid,x,y,val);
else{
tr[q].lc=update(tr[p].lc,l,mid,x,mid,val);
tr[q].rc=update(tr[p].rc,mid+1,r,mid+1,y,val);
}
return q;
}
int query(int p,int q,int l,int r,int x,int y)
{
if(x<=l&&r<=y)
{
return tr[q].sum-tr[p].sum;
}
int ans=(tr[q].lz-tr[p].lz)*(y-x+1);
int mid=l+r>>1;
if(x>mid) ans+=query(tr[p].rc,tr[q].rc,mid+1,r,x,y);
else if(y<=mid) ans+=query(tr[p].lc,tr[q].lc,l,mid,x,y);
else{
ans+=query(tr[p].lc,tr[q].lc,l,mid,x,mid);
ans+=query(tr[p].rc,tr[q].rc,mid+1,r,mid+1,y);
}
return ans;
}
int ans;
void dfs(int u,int fa,int pre,int res)
{ ans=max(ans,res);
for(auto it:e1[u])
{ if(it==fa) continue;
int l=first[it],r=first[it]+sz[it]-1;
int x=query(root[pre],root[tim],1,n,l,r);
if(x>0)
{ int cnt=pre;
while(query(root[cnt],root[tim],1,n,l,r)>0)
{
cnt++;
}
++tim;
root[tim]=update(root[tim-1],1,n,l,r,1);
int cha=cnt-pre;
dfs(it,u,cnt,res-cha+1);
tim--;
}
else{
++tim;
root[tim]=update(root[tim-1],1,n,l,r,1);
dfs(it,u,pre,res+1);
tim--;
}
}
return;
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
e1[i].clear(),e2[i].clear(),sz[i]=0;
for(int i=0;i<n-1;i++)
{
int a,b;
scanf("%d%d",&a,&b);
e1[a].push_back(b);
e1[b].push_back(a);
}
for(int i=0;i<n-1;i++)
{
int a,b;
scanf("%d%d",&a,&b);
e2[a].push_back(b);
e2[b].push_back(a);
}
top=0;
getdfn(1,-1);
idx=tim=0;
root[0]=build(1,n);
ans=1;
tim++;
root[1]=update(root[0],1,n,first[1],first[1]+sz[1]-1,1);
dfs(1,-1,0,1);
cout<<ans<<endl;
}
return 0;
}
hdu4348
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN = 1e5 + 10;
const int MAXM = 1e6 + 10;
const int INF=0x3f3f3f3f;
const int LINF=0x3f3f3f3f3f3f3f3f;
const int NINF=-INF-1;
const int mod=1e9+7;
#define PI acos(-1.0)
int a[MAXN],n,q;
int root[MAXN],idx;
struct node{
int lc,rc;
int sum,lz;
}tr[MAXN*30];
int build(int l,int r)
{
int p=++idx;
tr[p].sum=tr[p].lz=0;
if(l==r)
{ tr[p].sum=a[l];
return p;
}
int mid=l+r>>1;
tr[p].lc=build(l,mid),tr[p].rc=build(mid+1,r);
tr[p].sum=tr[tr[p].lc].sum+tr[tr[p].rc].sum;
return p;
}
int update(int p,int l,int r,int x,int y,int val)
{ int q=++idx;
tr[q]=tr[p];
tr[q].sum=tr[p].sum+val*(y-x+1);
if(x==l&&y==r)
{
tr[q].lz+=val;
return q;
}
int mid=l+r>>1;
if(x>mid) tr[q].rc=update(tr[p].rc,mid+1,r,x,y,val);
else if(y<=mid) tr[q].lc=update(tr[p].lc,l,mid,x,y,val);
else{
tr[q].lc=update(tr[p].lc,l,mid,x,mid,val);
tr[q].rc=update(tr[p].rc,mid+1,r,mid+1,y,val);
}
return q;
}
int query(int p,int l,int r,int x,int y)
{
if(x<=l&&r<=y)
{
return tr[p].sum;
}
int ans=tr[p].lz*(y-x+1);
int mid=l+r>>1;
if(x>mid) ans+=query(tr[p].rc,mid+1,r,x,y);
else if(y<=mid) ans+=query(tr[p].lc,l,mid,x,y);
else{
ans+=query(tr[p].lc,l,mid,x,mid);
ans+=query(tr[p].rc,mid+1,r,mid+1,y);
}
return ans;
}
signed main()
{
while(scanf("%lld%lld",&n,&q)==2)
{
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
idx=0;
root[0]=build(1,n);
int time=0;
char s[15];
while(q--)
{ int l,r;
scanf("%s",s);
if(s[0]=='Q')
{
scanf("%lld%lld",&l,&r);
printf("%lld\n",query(root[time],1,n,l,r));
}
else if(s[0]=='C')
{
int val;
scanf("%lld%lld%lld",&l,&r,&val);
++time;
root[time]=update(root[time-1],1,n,l,r,val);
}
else if(s[0]=='H')
{
int t;
scanf("%lld%lld%lld",&l,&r,&t);
printf("%lld\n",query(root[t],1,n,l,r));
}
else{
scanf("%lld",&time);
idx=root[time+1];
}
}
}
return 0;
}