给一颗有根树,根节点为1,再给定一个排列,长度为n,要求将排列切分成K段,定义每段的价值为该排列所有点及两两点之间lca中最浅节点的深度。要求输出K段区间所有可能的价值和中的最小值。n*K<=3e5。
解法:
很明显可以往dp方向思考。定义dp[i][j]为前i位切分成j段的价值和的最小值。在列出转移方程前,先说明观察到的若干性质:
性质1. 定义“一段排列所有点及两两点之间lca中最浅节点的深度”为T,当在排列末尾加上一个节点ai的时候,只需要求一下ai-1与ai的lca,再与之前的lca比较谁的深度小,维护深度的最小值即可。
这个性质很好拿数学归纳法证明:
a) 朴素情况是只有一个点,T就是自己的深度。两个点,T就是它们lca的深度。
b)假设长度为k的区间K可以这么维护T,定义该区间最浅深度的节点为tnode,加入一个新的节点ak+1时只有两种情况:
i) 该节点位于tnode为根的子树内,tnode依然是最浅的节点;
ii) 该节点位于tnode为根的子树外,则与区间K中任意一点求lca,等价于与tnode求lca,即lca(ak,ak+1)=lca(tnode,ak+1)。故依然用该方法可以维护长度为k+1区间的T值。
证毕。
性质2. 处理出相邻点间lca深度之后,比如 7 4 5 6 3 9 10,若最优切分方式里 4 与 3 在不同区间里,4 与 3 之间的任何数,即 5 6,划分给 3 区间或者 4 区间,都不会影响最终答案。
性质3. 区间末尾增加新的节点时,价值T一定是不增的。
细细思考这些性质之后,,,,就有:
在区间新加入一个点时,只需要按2种情况转移新加入的点即可:
a) 将该点放到前一个区间里,dp[i][j]=dp[i-1][j];
b) 将该点放到下一个区间里,dp[i][j]=dp[i-1][j-1]+depth[j] 或者 dp[i][j]=dp[i-2][j-1]+depth[lca(i-1,i)]。
这些情况全拿min维护一下就行啦,一切都放到dp数组里去转移了。。。
#include <bits/stdc++.h>
using namespace std;
typedef pair<int,int> pii;
const int maxn=300005;
int n,K,a[maxn],f[maxn],dep[maxn],dp[3][maxn],lca[maxn];
vector<pii> ask[maxn];
vector<int> G[maxn];
void init() {
for (int i=1;i<=n;++i) {
f[i]=i;
G[i].clear();
ask[i].clear();
}
for (int i=1;i<n;++i) {
ask[a[i]].push_back(pii(i+1,a[i+1]));
ask[a[i+1]].push_back(pii(i+1,a[i]));
}
memset(dep,0,(n+1)*sizeof dep[0]);
}
int F(int x) {
return x==f[x]?x:f[x]=F(f[x]);
}
inline void U(int a,int b) {
f[F(b)]=F(a);
}
void dfs(int u,int d) {
dep[u]=d;
for (int i=0;i<(int)ask[u].size();++i) {
int j=ask[u][i].second,p=ask[u][i].first;
if (dep[j])
lca[p]=dep[F(j)];
}
for (int i=0;i<(int)G[u].size();++i) {
int v=G[u][i];
if (dep[v])
continue;
dfs(v,d+1);
U(u,v);
}
}
int main()
{
while (scanf("%d%d",&n,&K)==2) {
for (int i=1;i<=n;++i)
scanf("%d",&a[i]);
init();
for (int i=0;i<n-1;++i) {
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,1);
memset(dp[0],0x3f,(n+1)*sizeof dp[0][0]);
dp[0][0]=0;
for (int i=1;i<=n;++i) {
dp[i%3][0]=0;
for (int j=1;j<=K;++j) {
dp[i%3][j]=min(dp[(i-1)%3][j],dp[(i-1)%3][j-1]+dep[a[i]]);
if (i>=2)
dp[i%3][j]=min(dp[i%3][j],dp[(i+1)%3][j-1]+lca[i]);
}
}
printf("%d\n",dp[n%3][K]);
}
return 0;
}