树上差分初(luan)学(xie)
前两天cy学长讲了简单树论,主要是lca,dfs序,树上差分。其实还讲了其他内容不过他没讲明白 我现在真正会用的也就这些=.=。
树上差分需要一些预备知识,即求树上两个节点的最近公共祖先(LCA),另外还有要对一个序列的差分(或者还有前缀和)有一定的了解。
有这样一类问题,在一棵树上多次对某两点之间的路或者路上的所有节点进行操作,最后要求对整棵树求某个答案。
一般的差分维护要解决一个序列上的这种多次区间操作问题,核心思想就是对每次操作只考虑区间的端点,这样每次操作的复杂度为O(1)。如果要操作的区间变为了树上的路,我们也可以考虑对路的两个端点操作然后这样就WA了hhhhh 。
树上差分分为两种,点差分和边差分(为了方便自己起了个名)。
又又又为了方便考虑,我们把一条路<u,v>拆成两条链,一条链为u -> lca(u,v),另一条就是lca(u,v) -> v。
(先讲点差分)
用num[]作为差分数组,如果对<u,v>上的所有点,其某个值(取决于实际问题)都要+1,那我们在num[]上的实际操作为num[u]++,num[v]++, num[lca(u,v)]–, num[fa[lca(u,v)]–。考虑我们拆出来的第一条链,要对上面的所有点+1,就相当于一般的差分,对端点u,num[u]++,而序列另一端点lca(u,v),我们实际的操作是num[fa[lca(u,v)]]–。再看另一条链,其实上面拆分时说得不太精确,由于lca(u,v)已经被算到了第一条链中,所以它不在这条链中作为端点考虑,而是作为除v外另一端点(即链上lca(u,v)的儿子)的外点,所以实际操作为num[v]++,num[lca(u,v)]–。
(再讲边差分)
而所谓的边差分就是如果要对路上的所有边上某个值+1,实际操作为num[u]++,num[v]++,num[lca(u,v)]-=2。这是因为现在num[i]保存的是关于i与i的父亲相连的那条边的信息,而且这条边不包含在我们所说的那两条链中,所以对两条链分别考虑,有num[u]++,num[lca(u,v)]–; num[v]++,num[lca(u,v)]–。
操作结束后,如果要得到所有操作后每个节点(或每条边)的答案,需要一个dfs将num[]数组累加一下。
两个例题:点差分 LuoguP3128
边差分POJ 3417
点差分题:给一棵n个节点的树,m次操作,每次操作对一条路上所有节点的运输压力+1,求最后树上运输压力最大的点其运输压力是多少。
解法:点差分板子题,处理完后累加num[]数组,然后扫一遍所有点求max就好了。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <cstdlib>
#define ll long long
using namespace std;
const int maxn=50005;
int f[maxn][21];
ll ans;
int h[maxn*2],nxt[maxn*2],v[maxn*2],p,pre[maxn];//链式前向星存图,pre[i]存i的父亲
ll num[maxn];
int n,m;
int depth[maxn];
int getnum()//读入优化,不加有可能会tle
{
int num=0;
bool f=1;
char c=getchar();
while(!isdigit(c))
{
if(c=='-')f=0;
c=getchar();
}
while(isdigit(c))
{
num=num*10+c-'0';
c=getchar();
}
return f ? num: -num;
}
void add(int a,int b)
{
++p;
nxt[p]=h[a];
h[a]=p;
v[p]=b;
}
void dfs(int x,int fa,int d)
{
depth[x]=d;
pre[x]=fa;
for(int i=h[x];i!=0;i=nxt[i])
{
if(v[i]==fa)
{
continue;
}
f[v[i]][0]=x;
dfs(v[i],x,d+1);
}
return;
}
int lca(int x,int y)//倍增法求lca,十分易懂,感谢高中时期郑公子的指导
{
if(depth[x] > depth[y])
{
int t=x;
x=y;
y=t;
}
for(int i=19;i>=0;i--)
{
if(depth[f[y][i]]>=depth[x])
{
y=f[y][i];
}
}
if(x==y)
{
return x;
}
for(int i=19;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
void work(int x,int fa)//再次dfs累加num[]
{
for(int i=h[x];i!=0;i=nxt[i])
{
int vv=v[i];
if(vv==fa)continue;
work(vv,x);
num[x]+=num[vv];
}
}
int main()
{
n=getnum();
m=getnum();
for(int i=1;i<n;i++)
{
int a=getnum();
int b=getnum();
add(a,b);
add(b,a);
}
dfs(1,0,1);
for(int i=1;i<=19;i++)
{
for(int j=1;j<=n;j++)
{
f[j][i]=f[f[j][i-1]][i-1];
}
}
for(int i=1;i<=m;i++)
{
int a=getnum();
int b=getnum();
num[a]++;
num[b]++;
num[lca(a,b)]--;
num[pre[lca(a,b)]]--;//点差分操作
}
work(1,0);
for(int i=1;i<=n;i++)
{
ans=max(ans,num[i]);
}
cout<<ans;
return 0;
}
第二题是个有点脑洞的边差分题,大意是先给一棵树,然后往树上加边,问删掉一条原来的边再删掉一条新边使原来的树变为两个之间不连通的子图的方案数。
因为我们不会树链剖分 要用边差分做,所以要想个合理的办法OmO。
num[i]表示i与其父亲相连的原边被新图上的环覆盖的次数,有三种情况:
Ⅰ:num[i]=0,这条边没有被包括在圈里,断开就能使图不连通,不过还必须断开一条新边,这样断开m条中的任意一条都是可以的,方案数+=m;
Ⅱ:num[i]=1,这条边包括在了一个圈里,或者说加了一条新边使其出现了一个圈,所以要断开这条边和那条新边,方案数+1;
Ⅲ:num[i]>1,这条边所在的圈至少加了大于一条新边(而且产生了不止一个圈),那至少要断开所有这些新边才能使图不连通,这是不合法的,答案+0。
然后边差分,然后改bug。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <cstdlib>
#define ll long long
using namespace std;
const int maxn=100005;
int f[maxn][21];
ll ans;
int h[maxn*2],nxt[maxn*2],v[maxn*2],p;
ll num[maxn];
int n,m;
int depth[maxn];
int getnum()
{
int num=0;
bool f=1;
char c=getchar();
while(!isdigit(c))
{
if(c=='-')f=0;
c=getchar();
}
while(isdigit(c))
{
num=num*10+c-'0';
c=getchar();
}
return f ? num: -num;
}
void add(int a,int b)
{
++p;
nxt[p]=h[a];
h[a]=p;
v[p]=b;
}
void dfs(int x,int fa,int d)
{
depth[x]=d;
for(int i=h[x];i!=0;i=nxt[i])
{
if(v[i]==fa)
{
continue;
}
f[v[i]][0]=x;
dfs(v[i],x,d+1);
}
return;
}
int lca(int x,int y)
{
if(depth[x] > depth[y])
{
int t=x;
x=y;
y=t;
}
for(int i=19;i>=0;i--)
{
if(depth[f[y][i]]>=depth[x])
{
y=f[y][i];
}
}
if(x==y)
{
return x;
}
for(int i=19;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
void work(int x,int fa)
{
for(int i=h[x];i!=0;i=nxt[i])
{
int vv=v[i];
if(vv==fa)continue;
work(vv,x);
num[x]+=num[vv];
}
}
int main()
{
n=getnum();
m=getnum();
for(int i=1;i<n;i++)
{
int a=getnum();
int b=getnum();
add(a,b);
add(b,a);
}
dfs(1,0,1);
for(int i=1;i<=19;i++)
{
for(int j=1;j<=n;j++)
{
f[j][i]=f[f[j][i-1]][i-1];
}
}
for(int i=1;i<=m;i++)
{
int a=getnum();
int b=getnum();
num[a]++;
num[b]++;
num[lca(a,b)]-=2;
}
work(1,0);
for(int i=2;i<=n;i++)
{
if(num[i]==0)
{
ans+=m;
}
if(num[i]==1)
{
ans++;
}
}
cout<<ans;
return 0;
}
第一次写博客,欢迎指正~