题意
一个序列的权值定义如下:假设该序列为 a 1 , . . . a k a_1,...a_k a1,...ak ,他的权值就是 ∀ 1 ≤ i ≤ k , m a x ( a 1 , a 2 , . . . a i ) \forall 1\le i\le k,max(a_1,a_2,...a_i) ∀1≤i≤k,max(a1,a2,...ai) 有多少种不同的取值。
现在有一个 1 1 1 到 n n n 的排列( n ≤ 1 0 5 n \le 10^5 n≤105 )。求这个排列的所有子序列的权值的 m m m 次方和( m ≤ 20 m\le 20 m≤20 )。
思路
关于 m m m 次方和,有这么一个公式:
n m = ∑ i = 1 min ( n , m ) S ( m , i ) ⋅ i ! ⋅ C n i n^m=\sum_{i=1}^{\min(n,m)}S(m, i)\cdot i!\cdot C_{n}^{i} nm=i=1∑min(n,m)S(m,i)⋅i!⋅Cni
其中 S ( m , i ) S(m,i) S(m,i) 是第二类斯特林数。
可以感性理解一下这个式子。左边表示 m m m 个不同的球放在 n n n 个不同的盒子里,盒子可以为空。右边枚举有几个盒子是空的,第二类斯特林数恰好求的又是 m m m 个球放在 i i i 个不同的盒子里,盒子不可以空的方案数。但是这个式子的推导与本题没有太大的关系。
然后考虑,假如所有子序列的权值放在一个可重集中, n k ∈ N n_k\in \N nk∈N 。答案就是
∑ n k ∈ N n k m = ∑ i = 1 min ( n , m ) S ( m , i ) ⋅ i ! ⋅ ∑ n k ∈ N C n k i \sum_{n_k\in \N}n_k^m=\sum_{i=1}^{\min(n,m)}S(m, i)\cdot i!\cdot \sum_{n_k\in \N}C_{n_k}^{i} nk∈N∑nkm=i=1∑min(n,m)S(m,i)⋅i!⋅nk∈N∑Cnki
然后我们并不需要直接计算 m m m 次方和,而是转而对每个 i i i 计算 ∑ n k ∈ N C n k i \sum_{n_k\in \N}C_{n_k}^{i} ∑nk∈NCnki 就好了。而组合数的递推是非常方便的。
那么我们考虑 DP ,设计 f i , j , h f_{i,j,h} fi,j,h 表示前 i i i 个数的子序列,组合数 C C C 上标那个数是 j j j ,目前子序列中最大值为 h h h 的所有子序列的 ∑ C n k i \sum C_{n_k}^{i} ∑Cnki 。
只要写出式子,就可以发现更新的时候操作并不多,可以用线段树维护,只需要写单点加,区间乘,区间查询就行了。
注意
注意常数。
用一个结构体存所有的 m m m 个 i , h i,h i,h 相同的 f i , j , h f_{i,j,h} fi,j,h 。这样的话使用线段树递归的次数会少 m m m 倍,而且乘法标记也不用重复打了。
众所周知递归是很慢的。
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 20 + 1, mod = 1e9 + 7;
int n, m, a[N];
int f[M], s[M][M], fac[M], ans;
template<class T>inline void read(T &x){
x = 0; bool fl = 0; char c = getchar();
while (!isdigit(c)){if (c == '-') fl = 1; c = getchar();}
while (isdigit(c)){x = (x<<3)+(x<<1)+c-'0'; c = getchar();}
if (fl) x = -x;
}
inline int add(int &x, int y){x += y; if (x >= mod) x -= mod;}
inline int pls(int x, int y){x += y; return (x >= mod ? x-mod : x);}
struct node{
int c[M];
node(){memset(c, 0, sizeof c);}
node operator * (int x){
node ret;
for (int i = 0; i <= m; ++ i)
ret.c[i] = 1LL*c[i]*x%mod;
return ret;
}
node operator + (node u){
node ret;
for (int i = 0; i <= m; ++ i)
ret.c[i] = pls(c[i], u.c[i]);
return ret;
}
}t[N<<2];
int laz[N<<2];
#define ls (u<<1)
#define rs (u<<1^1)
inline void push_up(int u){
t[u] = t[ls]+t[rs];
}
inline void push_down(int u){
if (laz[u] > 1){
t[ls] = t[ls]*laz[u];
laz[ls] = 1LL*laz[ls]*laz[u]%mod;
t[rs] = t[rs]*laz[u];
laz[rs] = 1LL*laz[rs]*laz[u]%mod;
laz[u] = 1;
}
}
void build(int u, int l, int r){
laz[u] = 1;
if (l == r){
if (l == 0) t[u].c[0] = 1;
return;
}
int mid = l+r>>1;
build(ls, l, mid);
build(rs, mid+1, r);
push_up(u);
}
void modify_mul(int u, int l, int r, int L, int R, int X){
if (L <= l && r <= R){
t[u] = t[u]*X;
laz[u] = 1LL*laz[u]*X%mod;
return;
}
int mid = l+r>>1;
push_down(u);
if (L <= mid) modify_mul(ls, l, mid, L, R, X);
if (mid < R) modify_mul(rs, mid+1, r, L, R, X);
push_up(u);
}
void modify_add(int u, int l, int r, int P, node X){
if (l == r){
for (int i = 1; i <= m; ++ i)
add(t[u].c[i], X.c[i]+X.c[i-1]);
add(t[u].c[0], X.c[0]);
return;
}
int mid = l+r>>1;
push_down(u);
if (P <= mid) modify_add(ls, l, mid, P, X);
else modify_add(rs, mid+1, r, P, X);
push_up(u);
}
node query(int u, int l, int r, int L, int R){
if (L <= l && r <= R) return t[u];
int mid = l+r>>1, ret = 0;
push_down(u);
if (mid < L) return query(rs, mid+1, r, L, R);
else if (R <= mid) return query(ls, l, mid, L, R);
else return query(ls, l, mid, L, R)+query(rs, mid+1, r, L, R);
}
void solve_c(){
build(1, 0, n);
for (int i = 1; i <= n; ++ i){
modify_add(1, 0, n, a[i], query(1, 0, n, 0, a[i]));
if (a[i] < n) modify_mul(1, 0, n, a[i]+1, n, 2);
}
}
void solve_ans(){
node tmp = query(1, 0, n, 1, n);
for (int i = 1; i <= m; ++ i)
f[i] = tmp.c[i];
s[1][1] = 1;
for (int i = 2; i <= m; ++ i)
for (int j = 1; j <= i; ++ j)
s[i][j] = pls(s[i-1][j-1], 1LL*s[i-1][j]*j%mod);
fac[0] = 1;
for (int i = 1; i <= m; ++ i)
fac[i] = 1LL*fac[i-1]*i%mod;
for (int i = 1; i <= m; ++ i)
add(ans, 1LL*s[m][i]*fac[i]%mod*f[i]%mod);
}
int main()
{
read(n); read(m);
for (int i = 1; i <= n; ++ i)
read(a[i]);
solve_c();
solve_ans();
printf("%d\n", ans);
return 0;
}