Description
小Yuuka遇到了一个题目:有一个序列a_1,a_2,?,a_n,q次操作,每次把一个区间内的数改成区间内的最大值,问
最后每个数是多少。小Yuuka很快地就使用了线段树解决了这个问题。于是充满智慧的小Yuuka想,如果操作是随机
的,即在这q次操作中每次等概率随机地选择一个区间[l,r](1≤l≤r≤n),然后将这个区间内的数改成区间内最大
值(注意这样的区间共有(n(n+1))/2个),最后每个数的期望大小是多少呢?小Yuuka非常热爱随机,所以她给出
的输入序列也是随机的(随机方式见数据规模和约定)。对于每个数,输出它的期望乘((n(n+1))/2)^q再对10
^9+7取模的值。
Input
第一行包含2个正整数n,q,表示序列里数的个数和操作的个数。接下来1行,包含n个非负整数a1,a2...an。N<=400,Q<=400
Output
输出共1行,包含n个整数,表示每个数的答案
Sample Input
5 5
1 5 2 3 4
1 5 2 3 4
Sample Output
3152671 3796875 3692207 3623487 3515626
HINT
Source
DP
首先离散化
考虑sum[i][j]表示i这个位置变成j的方案数
这样不好做,我们考虑把等于j变成小于等于j
对于每一个数x,我们考虑它能影响的区间,假设是[L,R],那么a[L-1]>a[x],a[R+1]>a[x],区间中其他数都小于等于x
然后考虑dp[k][i][j]表示k次操作后<=x的极长区间为[i,j]的方案数
那么有三种情况
第一,区间被缩短了,假设是由[l,j]转移过来的,那么当前选的区间一定是[1到l-1都可以,i-1],另一边同理
第二,区间长度不变,那么就是选择的区间一定被分成了3块,[1,i-1],[i,j],[j+1,n]统计方案数之和即可
复杂度是n^2*q的
复杂度证明:由于是随机数据,值域>>n
,我们可以先钦定每个数都不同,可以认为是一个排列
我们的复杂度是(每个极长区间长度)^2*q的
一个长度为L的极长区间产生贡献的概率
C(n,L + 2) * 2 * L! * (n - L - 2)! / n!,贡献乘上L^2
表示把区间和区间外的两个点提取出来,然后从n个数选出L+2个数,最大的两个值分配给两端,中间和区间外的随意
化简一下发现是L^2 / (L + 1) * (L + 2)<1
区间总个数是n^2的,所以复杂度是这个
#include <bits/stdc++.h>
#define xx first
#define yy second
#define mp make_pair
#define pb push_back
#define fill( x, y ) memset( x, y, sizeof x )
#define copy( x, y ) memcpy( x, y, sizeof x )
using namespace std;
typedef long long LL;
typedef pair < int, int > pa;
const int MAXN = 405;
const int mod = 1e9 + 7;
int x[MAXN], y[MAXN], rk[MAXN], n, q, cnt[MAXN];
LL dp[2][MAXN][MAXN], sum[MAXN][MAXN];
inline bool cmp(int a, int b) { return x[ a ] < x[ b ]; }
inline void solve(int l, int r, int now)
{
for( int i = l ; i <= r ; i++ )
for( int j = i ; j <= r ; j++ )
dp[ 0 ][ i ][ j ] = dp[ 1 ][ i ][ j ] = 0;
dp[ 0 ][ l ][ r ] = 1;
LL ret = 0;
for( int k = 1, cur = 1 ; k <= q ; k++, cur ^= 1 )
{
for( int i = l ; i <= r ; i++ )
{
ret = 0;
for( int j = r ; j >= i ; j-- )
dp[ cur ][ i ][ j ] = ret, ret += dp[ cur ^ 1 ][ i ][ j ] * ( n - j );
}
for( int j = l ; j <= r ; j++ )
{
ret = 0;
for( int i = l ; i <= j ; i++ )
dp[ cur ][ i ][ j ] += ret, ret += dp[ cur ^ 1 ][ i ][ j ] * ( i - 1 );
}
for( int i = l ; i <= r ; i++ )
for( int j = i ; j <= r ; j++ )
( dp[ cur ][ i ][ j ] += dp[ cur ^ 1 ][ i ][ j ] * ( cnt[ i - 1 ] + cnt[ j - i + 1 ] + cnt[ n - j ] ) ) %= mod;
}
for( int i = l ; i <= r ; i++ )
{
ret = 0;
for( int j = r ; j >= i ; j-- )
ret += dp[ q & 1 ][ i ][ j ], sum[ j ][ rk[ now ] ] += ret;
}
for( int i = l ; i <= r ; i++ ) sum[ i ][ rk[ now ] ] %= mod;
}
int main()
{
#ifdef wxh010910
freopen( "data.in", "r", stdin );
#endif
scanf( "%d%d", &n, &q );
for( int i = 1 ; i <= n ; i++ ) scanf( "%d", &x[ i ] ), y[ i ] = i, cnt[ i ] = i * ( i + 1 ) >> 1;
sort( y + 1, y + n + 1, cmp );
for( int i = 1 ; i <= n ; i++ ) rk[ y[ i ] ] = i;
for( int i = 1, l, r ; i <= n ; i++ )
{
l = r = i;
while( l && x[ l ] <= x[ i ] ) l--; while( r <= n && x[ r ] <= x[ i ] ) r++;
solve( l + 1, r - 1, i );
}
for( int i = 1 ; i <= n ; i++ )
{
int ans = 0;
for( int j = 1 ; j <= n ; j++ )
{
if( !sum[ i ][ j ] ) continue;
for( int k = 1 ; k < j ; k++ ) sum[ i ][ j ] = ( sum[ i ][ j ] - sum[ i ][ k ] + mod ) % mod;
ans = ( 1LL * x[ y[ j ] ] * sum[ i ][ j ] + ans ) % mod;
}
printf( "%d%c", ans, i == n ? '\n' : ' ' );
}
}