一、题目
给你一个数组 a a a,其中有 n n n 个元素。你现在要标记若干个元素,使得它们单调上升。你按照下面两个步骤来进行操作:
s t e p 1 \tt step1 step1:如果一个元素称为好元素,当且仅当它被标记以后,标记过的元素仍然是单调上升的。设有 k k k 个好元素。
s t e p 2 \tt step2 step2:从 k k k 个好元素中随机选一个进行标记,然后跳到 s t e p 1 \tt step1 step1,如果 k = 0 k=0 k=0,则将所有未标记的元素删除,结束操作。
问最后期望的元素个数是多少?取模 1 e 9 + 7 1e9+7 1e9+7
数据规模: n < = 2000 , a i < = n , a i n<=2000,a_i<=n,a_i n<=2000,ai<=n,ai 均不相等
二、解法
抓住问题的一个背景:我们的好元素是定义在单调上升序列的基础上的,基于这一点和
n
≤
2000
n\leq 2000
n≤2000,我们可以尝试用区间
d
p
dp
dp 解决这道题,设
f
[
i
]
[
j
]
f[i][j]
f[i][j] 为
[
i
,
j
]
[i,j]
[i,j] 的期望好元素个数,转移枚举
k
k
k:
f
[
i
]
[
j
]
=
∑
k
f
[
i
]
[
k
]
+
f
[
k
]
[
j
]
c
n
t
k
+
1
f[i][j]=\frac{\sum_{k}f[i][k]+f[k][j]}{cnt_k}+1
f[i][j]=cntk∑kf[i][k]+f[k][j]+1
k
k
k 需要满足两个条件:
i
<
k
<
j
,
a
[
i
]
<
k
<
a
[
j
]
i<k<j,a[i]<k<a[j]
i<k<j,a[i]<k<a[j],这其实是一个二维偏序问题。由于第一个偏序关系是下标,所以我们可以通过安排恰当的
d
p
dp
dp 顺序来解决它,第二个偏序关系就用树状数组之类的数据结构来做,时间复杂度
O
(
n
2
log
n
)
O(n^2\log n)
O(n2logn) 。
具体来说,我们从大到小枚举 i i i,从小到大枚举 j j j,对于每个 i i i 和 j j j 我们都维护一个树状数组(其实就是分开算转移中的两个 f f f ),处理 f [ i ] [ j ] f[i][j] f[i][j] 的时候加入 d p [ i ] [ j − 1 ] dp[i][j-1] dp[i][j−1] 和 d p [ i + 1 ] [ j ] dp[i+1][j] dp[i+1][j]
观察我们的 d p dp dp,其实需要建立两个虚点 a [ 0 ] = 0 , a [ n + 1 ] = n + 1 a[0]=0,a[n+1]=n+1 a[0]=0,a[n+1]=n+1,答案就是 d p [ 0 ] [ n + 1 ] dp[0][n+1] dp[0][n+1]
#include <cstdio>
const int M = 2005;
const int MOD = 1e9+7;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,a[M],dp[M][M];
struct node
{
int b1[M],b2[M];
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int y)//a[x]处加入点值y
{
for(int i=x;i<=n;i+=lowbit(i))
{
b1[i]++;
b2[i]=(b2[i]+y)%MOD;
}
}
int ask1(int x)
{
if(x<=0) return 0;
int res=0;
for(int i=x;i>0;i-=lowbit(i))
res+=b1[i];
return res;
}
int ask2(int x)
{
if(x<=0) return 0;
int res=0;
for(int i=x;i>0;i-=lowbit(i))
res=(res+b2[i])%MOD;
return res;
}
int q1(int l,int r)
{
return ask1(r-1)-ask1(l);
}
int q2(int l,int r)
{
return ask2(r-1)-ask2(l);
}
}A[M],B[M];
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
signed main()
{
n=read();
for(int i=1;i<=n;i++)
a[i]=read();
a[n+1]=n+1;n++;
for(int i=n;i>=0;i--)
for(int j=i+2;j<=n;j++)
{
if(a[i]<a[j-1]) A[i].add(a[j-1],dp[i][j-1]);
if(a[i+1]<a[j]) A[j].add(a[i+1],dp[i+1][j]);
if(a[i]>=a[j]) continue;
int inv=A[i].q1(a[i],a[j]);
if(inv)//这样才能转移嘛,要不1不能加上去
dp[i][j]=(A[i].q2(a[i],a[j])+A[j].q2(a[i],a[j]))*qkpow(inv,MOD-2)%MOD+1;
}
printf("%lld\n",(dp[0][n]+MOD)%MOD);
}