二叉搜索树
题解
首先,我们发现对于任何一棵子树内,他所包含带你的权值映射到我们所有加入了树内的点一定是连续的。
所以如果我们倒着来,每次将两棵子树接在一个根上,接出来的根的值域也一定是连续的,我们可以考虑区间dp。
对于顺序固定的点,我们明显可以直接知道它两边的子树是哪些,直接连起来即可。
对于那些顺序不顾定的点,我们就只能通过区间dp进行处理。
但事实上,我们每次
d
p
dp
dp的必须是一个连续段,也就是说,我们总共的
d
p
dp
dp总共会被按照
[
1
,
l
)
[1,l)
[1,l)建出来的树分为多个小段,我们要对这些小段内部分别进行区间dp。
我们不妨将
[
l
,
r
]
[l,r]
[l,r]中的点全部按照大小排序,记
d
p
l
,
r
dp_{l,r}
dpl,r表示值域范围为
[
a
l
,
a
r
]
[a_{l},a_{r}]
[al,ar]的点所构成的子树的深度和,容易得到转移方程式
d
p
l
,
r
=
max
m
i
d
=
l
r
d
p
l
,
m
i
d
−
1
+
d
p
m
i
d
+
r
,
r
+
r
−
l
dp_{l,r}=\max_{mid=l}^{r}dp_{l,mid-1+dp_{mid+r,r}}+r-l
dpl,r=mid=lmaxrdpl,mid−1+dpmid+r,r+r−l
但显然这样是不完全的,因为我们要考虑后面那些顺序固定的点的贡献,也就是说,如果我们枚举的根恰好是左右端点之一,我们要将它那一侧的顺序固定的点所建成的子树大小加进去。
求出这个这一段区间的
d
p
dp
dp值后,我们可以随便钦定一个节点作为根,然后继续处理前面那一段。
至于维护某个节点所处子树的根,我们可以用并查集进行维护。
时间复杂度 O ( α ( n ) n + ( r − l ) 3 ) O\left(\alpha(n)n+(r-l)^3\right) O(α(n)n+(r−l)3)。
源码
#include<bits/stdc++.h>
using namespace std;
#define MAXN 100005
#define lowbit(x) (x&-x)
#define reg register
#define pb push_back
#define mkpr make_pair
#define fir first
#define sec second
#define lson (rt<<1)
#define rson (rt<<1|1)
typedef long long LL;
typedef unsigned long long uLL;
const LL INF=0x3f3f3f3f3f3f3f3f;
const int mo=1e9+7;
const int inv2=499122177;
const int jzm=2333;
const int zero=10000;
const int lim=30000;
const int orG=3,invG=332748118;
const double Pi=acos(-1.0);
const double eps=1e-5;
typedef pair<int,int> pii;
template<typename _T>
_T Fabs(_T x){return x<0?-x:x;}
template<typename _T>
void read(_T &x){
_T f=1;x=0;char s=getchar();
while(s>'9'||s<'0'){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
template<typename _T>
void print(_T x){if(x<0){x=(~x)+1;putchar('-');}if(x>9)print(x/10);putchar(x%10+'0');}
LL gcd(LL a,LL b){return !b?a:gcd(b,a%b);}
int add(int x,int y,int p){return x+y<p?x+y:x+y-p;}
void Add(int &x,int y,int p){x=add(x,y,p);}
int qkpow(int a,int s,int p){int t=1;while(s){if(s&1LL)t=1ll*a*t%p;a=1ll*a*a%p;s>>=1LL;}return t;}
int n,b[MAXN],fa[MAXN],L[MAXN],R[MAXN],l,r;LL dp[MAXN],g[405][405];
int findSet(int x){return fa[x]==x?x:fa[x]=findSet(fa[x]);}
signed main(){
read(n);
for(int i=1;i<=n;i++)read(b[i]);read(l);read(r);
for(int i=n;i>r;i--){
fa[b[i]]=b[i];L[b[i]]=b[i];R[b[i]]=b[i];dp[b[i]]=1;
if(fa[b[i]-1]){
int x=findSet(b[i]-1);L[b[i]]=L[x];
dp[b[i]]+=dp[x]+1ll*(R[x]-L[x]+1);fa[x]=b[i];
}
if(fa[b[i]+1]){
int x=findSet(b[i]+1);R[b[i]]=R[x];
dp[b[i]]+=dp[x]+1ll*(R[x]-L[x]+1);fa[x]=b[i];
}
}
sort(b+l,b+r+1);
for(int i=l,j;i<=r;i=j+1){
j=i;
while(j<r){
if(b[j+1]==b[j]+1){j++;continue;}
if(!fa[b[j]+1]||!fa[b[j+1]-1])break;
if(findSet(b[j]+1)==findSet(b[j+1]-1)){j++;continue;}
break;
}
for(int len=1;len<=j-i+1;len++)
for(int li=1,ri=len;ri<=j-i+1;li++,ri++){
g[li][ri]=INF;
for(int mid=li;mid<=ri;mid++){
LL tmp=0;int lt=b[li+i-1]-1,rt=b[ri+i-1]+1;
if(li==mid){if(fa[lt])tmp+=dp[findSet(lt)];}else tmp+=g[li][mid-1];
if(ri==mid){if(fa[rt])tmp+=dp[findSet(rt)];}else tmp+=g[mid+1][ri];
g[li][ri]=min(g[li][ri],tmp);
}
int lt=fa[b[li+i-1]-1]?L[findSet(b[li+i-1]-1)]:b[li+i-1];
int rt=fa[b[ri+i-1]+1]?R[findSet(b[ri+i-1]+1)]:b[ri+i-1];
g[li][ri]+=1ll*(rt-lt+1);
}
for(int k=i;k<=j;k++){
fa[b[k]]=b[k];L[b[k]]=b[k];R[b[k]]=b[k];
if(fa[b[k]-1]){int x=findSet(b[k]-1);fa[x]=b[k];L[b[k]]=L[x];}
if(fa[b[k]+1]){int x=findSet(b[k]+1);fa[x]=b[k];R[b[k]]=R[x];}
}
dp[findSet(b[j])]=g[1][j-i+1];
}
for(int i=l-1;i>0;i--){
fa[b[i]]=b[i];L[b[i]]=b[i];R[b[i]]=b[i];dp[b[i]]=1;
if(fa[b[i]-1]){
int x=findSet(b[i]-1);L[b[i]]=L[x];
dp[b[i]]+=dp[x]+1ll*(R[x]-L[x]+1);fa[x]=b[i];
}
if(fa[b[i]+1]){
int x=findSet(b[i]+1);R[b[i]]=R[x];
dp[b[i]]+=dp[x]+1ll*(R[x]-L[x]+1);fa[x]=b[i];
}
}
printf("%lld\n",dp[findSet(b[1])]);
return 0;
}