Description
Solution
一开始这道题就看错题了,我直接用一轮的期望作为下一轮的值,结果还以为很容易就能用n^3搞出来,结果搞了半天。
因为最后序列的答案不会超过原序列的最大值,所以我们可以考虑对原序列离散化一下,然后考虑每个位置最后的值是原序列第k大的期望,那么我们可以设sum[i][j]表示第i个点值小于等于序列中第j大的方案是多少,然后用sum[i][j]-sum[i][j-1]就是当前的位置的值等于序列中的第j大的方案数是多少。
我们要考虑怎么求这个值。
那么我们可以求第i轮,[L,R]区间的值小于等于x的方案数,x是当前要求区间的最大值。
我们对于一个x,可以用一个极大的区间[l,r]来表示,就是当前[l,r]区间的最大值是x,那么l-1或r+1的值大于x。这个可以预处理出来。
然后对于第k轮的极大区间[l,r],我们来分类讨论一下上一轮的极大区间的可能的情况。
1、上一次的极大区间是[u,j],那么就是说我们的操作区间是[1~u-1,i-1],这样可以让[u,i-1]这段的值提升,让极大区间缩小(现在考虑的极大区间与x无关,只是单纯的考虑极大区间,就是l-1和r+1的值都是大于中间的最大值的)
2、极大区间是[i,v],和上面类似
3、极大区间是[i,j],那么可以操作一段[u,v]这段[u,v]可以完全被[i,j]包含或或者不相交,那么这个方案数是可以直接算的。
然后可以通过上面算出的f来算出sum值。
注意求sum值的时候,因为数字不是连续的,所以有一些值是0,要找到前面第一个不是0的地方减掉。
然后还要卡一下常,注意在n^3里面是不能带mod的,这样会很慢。
在随机数据下这个方法是n^3的,理论复杂度虽然n^4……
Code
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<math.h>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fod(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int maxn=507,mo=998244353;
ll i,j,k,t,n,m,ans,p,q,num;
ll f[2][maxn][maxn],sum[maxn][maxn],b[maxn],l[maxn],r[maxn],c[maxn],d[maxn];
struct node{
ll x,y;
}a[maxn];
bool cmp(node x,node y){return x.x<y.x;}
ll qsm(ll x,ll y){
ll z=1;
for(;y;y/=2,x=x*x%mo)if(y&1)z=z*x%mo;
return z;
}
void solve(int x,int L,int R){
int i,j,k,u,v;
ll t;
p=0;
fo(i,L,R)fo(j,i,R)f[0][i][j]=0;
f[0][L][R]=1;
fo(k,1,m){
q=(p^1);
fo(j,L,R){
t=0;
fo(i,L,j){
f[q][i][j]=t;
t=t+f[p][i][j]*(i-1);
}
}
fo(i,L,R){
t=0;
fod(j,R,i){
f[q][i][j]=(f[q][i][j]+t+(b[i-1]+b[n-j]+b[j-i+1])*f[p][i][j])%mo;
t=t+f[p][i][j]*(n-j);
}
}
p=q;
}
fo(i,L,R){
t=0;
fod(j,R,i){
t=(t+f[p][i][j])%mo;
sum[j][x]=(sum[j][x]+t)%mo;
}
}
}
int main(){
scanf("%lld%lld",&n,&m);
fo(i,1,n)scanf("%lld",&a[i].x),a[i].y=i,b[i]=i*(i+1)/2;
sort(a+1,a+1+n,cmp);
num=1;d[a[1].y]=1;c[1]=a[1].x;
fo(i,2,n){
if(a[i].x!=a[i-1].x)++num;
d[a[i].y]=num,c[num]=a[i].x;
}
fo(i,1,n){j=i;while(j>1&&d[i]>=d[j-1])j=l[j-1];l[i]=j;}
fod(i,n,1){j=i;while(j<n&&d[i]>=d[j+1])j=r[j+1];r[i]=j;}
fo(i,1,n)solve(d[i],l[i],r[i]);
fo(i,1,n){
ans=0;t=0;
fo(j,1,num){
if(!sum[i][j])continue;
sum[i][j]=sum[i][j]-t;
ans=(ans+sum[i][j]*c[j]%mo)%mo;
t=(t+sum[i][j])%mo;
}
ans=(ans+mo)%mo;
ans=ans*qsm(qsm(n*(n+1)/2,m)%mo,mo-2)%mo;
printf("%lld\n",ans);
}
}