bzoj4543 长链剖分优化树形DP
题意:
树上统计三元组(i,j,k)两两距离相等的数目
思路:
看起来有两种,但合并后其实只有一种
前面两种其实只是 j = 0 和 j = d j=0和j=d j=0和j=d的特例而已
想清楚这点非常重要,这决定我们能否顺利推出方程,所以我们需要的就是两个东西
f [ i ] [ j ] f[i][j] f[i][j]表示以i为根距离i距离为j的有多少个,容易想到转移 f [ u ] [ j ] + = f [ v ] [ j − 1 ] f[u][j]+=f[v][j-1] f[u][j]+=f[v][j−1],
g [ i ] [ j ] g[i][j] g[i][j]表示以i为根的子树中有多少个点对(x,y)满足x,y到其lca距离为d,lca->i距离为d-j
g [ i ] [ j ] 比 起 g [ v ] g[i][j]比起g[v] g[i][j]比起g[v],其实就是在原来上移了一个点,所以减的变少了,所以
g [ u ] [ i ] + = g [ v ] [ i + 1 ] g[u][i]+=g[v][i+1] g[u][i]+=g[v][i+1]
但是我们一定要注意,在做树形dp的时候,从子树转移的话是每个儿子节点逐一转移过来的,包括了当前已经被儿子节点转移过的根节点和当前还没被转移的儿子节点,也就是其实是这样的,我们转移的时候要考虑两点,一个是仅由当前子节点贡献,一个是由当前子节点和当前节点(被画起来)共同贡献,当然这个可能没有贡献,但在这题中, g [ u ] [ j ] g[u][j] g[u][j]是有的, g [ u ] [ j + 1 ] + = f [ u ] [ j + 1 ] ∗ f [ v ] [ j ] g[u][j+1]+=f[u][j+1]*f[v][j] g[u][j+1]+=f[u][j+1]∗f[v][j]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEfDw4xG-1615809410197)(C:\Users\98753\AppData\Roaming\Typora\typora-user-images\image-20210315174813078.png)]
那么fg有了,答案怎么统计呢,其实这题就是个分配问题,3点可以分配为1,2或者2,1,
通过乘法原理保证不重
a n s + = g [ x ] [ 0 ] ans+=g[x][0] ans+=g[x][0]
a n s + = f [ u ] [ j ] ∗ g [ v ] [ j + 1 ] ans+=f[u][j]*g[v][j+1] ans+=f[u][j]∗g[v][j+1]
a n s + = g [ u ] [ j ] ∗ f [ v ] [ j − 1 ] ans+=g[u][j]*f[v][j-1] ans+=g[u][j]∗f[v][j−1]
还有需要注意的是由于g是倒序,所以f得加两倍空间
具体的已经写在代码上了
需要注意的是枚举的范围
时空复杂度 O ( n ) O(n) O(n)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
ll *f[maxn],*g[maxn],tmp[maxn*3],*id=tmp,ans;
int md[maxn],dep[maxn],hson[maxn],head[maxn],next1[maxn<<1],ver[maxn<<1],tot,n;
void add(int x,int y){
ver[++tot]=y,next1[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int f){
for(int i=head[x];i;i=next1[i]){
int y=ver[i];
if(y==f)continue;
dfs1(y,x);
if(md[y]>md[hson[x]])hson[x]=y;
}
md[x]=md[hson[x]]+1;//md表示x的高度 叶子为1
}
void dfs(int x,int fa){
if(hson[x])
f[hson[x]]=f[x]+1,g[hson[x]]=g[x]-1,dfs(hson[x],x);
f[x][0]=1;ans+=g[x][0];//重儿子统计
for(int i=head[x];i;i=next1[i]){
int y=ver[i];
if(y==fa||y==hson[x])continue;
f[y]=id;id+=md[y]<<1;g[y]=id;id+=md[y];
dfs(y,x);
for(int j=0;j<md[y];++j){
if(j)
ans+=f[x][j-1]*g[y][j];//轻儿子在此统计了g[y][1]即对g[x][0]的贡献
ans+=g[x][j+1]*f[y][j];//最多到md[y]-1
}
for(int j=0;j<md[y];++j){
g[x][j+1]+=f[x][j+1]*f[y][j];
if(j)
g[x][j-1]+=g[y][j];
f[x][j+1]+=f[y][j];//最多可以更新到md[y]
}
}
}
int main(){
scanf("%d",&n);
int a,b;
for(int i=1;i<n;++i){
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
dfs1(1,0);
f[1]=id;id+=md[1]<<1;
g[1]=id;id+=md[1];
dfs(1,0);
cout<<ans<<"\n";
return 0;
}