problem
solution
observation1 : \text{observation1}: observation1: 对于一个非空子段 [ l , r ] [l,r] [l,r],最后一个元素 a r a_r ar 一定不会被操作。
observation2 : \text{observation2}: observation2: 基于上一条进一步地有,对于一个非空子段,一段连续后缀满足 a i ≤ a i + 1 a_i\le a_{i+1} ai≤ai+1 一定不会被操作。
observation3 : \text{observation3}: observation3: 操作元素后产生的若干个数形成的 [ l , r ] [l,r] [l,r] 新序列,不仅满足不降,而且要开头尽可能的大。
observation4 : \text{observation4}: observation4: 基于第三点对于新拆出来的区间,其中每个数应该尽可能地平均。
我们假设一个非空子段中 a i > a i + 1 a_i>a_{i+1} ai>ai+1。
则至少要对 a i a_i ai 进行 ⌈ a i a i + 1 ⌉ − 1 \big\lceil\frac{a_i}{a_{i+1}}\big\rceil-1 ⌈ai+1ai⌉−1 次操作,拆分成 ⌈ a i a i + 1 ⌉ \big\lceil\frac{a_i}{a_{i+1}}\big\rceil ⌈ai+1ai⌉ 个数。
这么多个数的数值尽可能平均,所以最小的数值应为: ⌊ a i ⌈ a i a i + 1 ⌉ ⌋ \Big\lfloor\frac{a_i}{\big\lceil\frac{a_i}{a_{i+1}}\big\rceil}\Big\rfloor ⌊⌈ai+1ai⌉ai⌋。
设 f ( i , x ) : f(i,x): f(i,x): 满足 i ≤ j i\le j i≤j 的所有子段 [ i , j ] [i,j] [i,j] 操作使其不降后的第一个元素值为 x x x 的方案数。
即有多少个 i i i 开头的非空子段,贪心地操作后,形成的不降序列满足开头元素为 x x x。
暴力的转移有 f ( i , x ) = ∑ y f ( i + 1 , y ) f(i,x)=\sum_{y}f(i+1,y) f(i,x)=∑yf(i+1,y)。注意是通过枚举 y y y 计算出 x x x 的最优取值。
这个似乎是 O ( n 2 ) O(n^2) O(n2) 的欸?非也非也。
用 y y y 来计算最优取值 x = ⌊ a i ⌈ a i y ⌉ ⌋ x=\Big\lfloor\frac{a_i}{\big\lceil\frac{a_i}{y}\big\rceil}\Big\rfloor x=⌊⌈yai⌉ai⌋,即 f ( i + 1 , y ) → f ( i , ⌊ a i ⌈ a i y ⌉ ⌋ ) f(i+1,y)\rightarrow f(i,\Big\lfloor\frac{a_i}{\big\lceil\frac{a_i}{y}\big\rceil}\Big\rfloor) f(i+1,y)→f(i,⌊⌈yai⌉ai⌋)。
这很像数论分块的样子,经典根号范围。
所以这里只有 n \sqrt{n} n 种不同的取值,我们可以把开 vector \text{vector} vector 记录这些取值。
起到优化空间的作用,时间复杂度 O ( n n ) O(n\sqrt{n}) O(nn)。
至于统计答案,则是在 f ( i + 1 , y ) f(i+1,y) f(i+1,y) 转移时加上, f ( i + 1 , y ) ∗ i ∗ ( ⌊ a i a i + 1 ⌋ − 1 ) f(i+1,y)*i*(\big\lfloor\frac{a_i}{a_{i+1}}\big\rfloor-1) f(i+1,y)∗i∗(⌊ai+1ai⌋−1)。
乘上 i i i 是因为我们状态定义的是后缀,这段后缀可能存在 [ 1 , i ] [1,i] [1,i] 共 i i i 种开头的子段中,都要产生贡献。
最后注意过程中的数组清空问题。
code
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define mod 998244353
#define maxn 100005
vector < int > v[2];
int f[2][maxn];
int T, n, ans;
int a[maxn];
signed main() {
scanf( "%lld", &T );
while( T -- ) {
scanf( "%lld", &n );
for( int i = 1;i <= n;i ++ ) scanf( "%lld", &a[i] );
for( int i = n;i;i -- ) {
int k = i & 1, lst = a[i];
v[k].push_back( a[i] );
f[k][a[i]] = 1;
for( int y : v[k ^ 1] ) {
int t = (int)ceil( a[i] * 1.0 / y );
int x = a[i] / t;
f[k][x] += f[k ^ 1][y];
( ans += ( t - 1 ) * i % mod * f[k ^ 1][y] ) %= mod;
if( lst ^ x ) v[k].push_back( x ), lst = x;
}
for( int x : v[k ^ 1] ) f[k ^ 1][x] = 0;
v[k ^ 1].clear();
}
printf( "%lld\n", ans );
for( int i : v[0] ) f[0][i] = 0;
for( int i : v[1] ) f[1][i] = 0;
v[0].clear(), v[1].clear();
ans = 0;
}
return 0;
}