题面
JZM 有一个绝妙的想法,记作长度为 n 的序列 A,可惜被 SY 翻转了其中一个区间,得到了序列 B。
给定序列 B,在可能的众多原序列 A 中,JZM 需要你求出最小字典序的序列 A。但是你不需要输出整个序列 A,你只需要输出这个序列的权值。记一个序列的权值为 ∑ i = 1 n A i ⋅ 11451 4 n − i ∑^n_{i=1} A_i·114514^{n−i} ∑i=1nAi⋅114514n−i 对 998244353 取模的结果。
1 ≤ n ≤ 5 × 1 0 6 , 0 < B i < 998244353 1 ≤ n ≤ 5 × 10^6,0 < B_i < 998244353 1≤n≤5×106,0<Bi<998244353
题解
真心不难,但就是没做出来 😭
首先的第一步推导:从左边开始找,遇到第一个值大于后缀最小值的位置,那么这个位置一定是要翻转的区间的左端点。这很容易证明,反证法即可,考试时也想到了的。
然后在后面决定一个右端点,使得答案字典序最小。我当时竟然全去想后缀数组了,没有意识到该问题的特殊性:
- 后缀数组:在一个串内的规则子串(后缀),把它们排序
- 此题:在一个串内的不规则子串,求最值
求最值,且我们只求一次,那不直接遍历一遍找?就算是不规则的两个子串(翻转后子串后半段为先前串头后面的原串部分,该比较并不是在常规字串中的),也可以用最初进行字符串排序的算法二分+哈希啊,这样一来时间复杂度 O ( n log n ) O(n\log n) O(nlogn) ,可以卡过。
CODE
我用的双哈希,特判了一下,使其在随机数据下期望 O ( n ) O(n) O(n) 。
#include<set>
#include<map>
#include<stack>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 500005
#define DB double
#define LL long long
#define ENDL putchar('\n')
#define lowbit(x) (-(x) & (x))
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
const int MOD = 998244353;
int qkpow(int a,int b,int MOD) {
int res = 1;
while(b > 0) {
if(b & 1) res = res *1ll* a % MOD;
a = a *1ll* a % MOD; b >>= 1;
}return res;
}
inline void MD(int &x) {if(x>=MOD)x-=MOD;}
inline int MD(int x,int MOD) {return x>=MOD? (x-MOD):x;}
int n,m,i,j,s,o,k;
int a[MAXN],mi[MAXN],le[MAXN];
struct it{
int ad,nm,le;it(){ad=0;nm=MOD;le=0;}
it(int A,int N,int L){ad=A;nm=N;le=L;}
};
bool operator < (it a,it b) {
if(a.nm == b.nm) return a.le > b.le;
return a.nm < b.nm;
}
bool operator > (it a,it b) {
return b < a;
}
bool operator == (it a,it b) {return a.nm == b.nm && a.le == b.le;}
it b[MAXN];int cnb;
void Push(it x) {
if(cnb == 0) {b[++ cnb] = x;return ;}
if(x < b[1]) {b[cnb = 1] = x;return ;}
if(x == b[1]) {b[++ cnb] = x;return ;}
return ;
}
stack<it> st;
int Ld;
it getit(int ad,int le) {
int ad2 = ad-le+1;
if(ad2 < Ld) ad2 = ad + (Ld-ad2);
return it(ad,a[ad2],le);
}
const int MOD1 = 998444353,MOD2 = 1000011007;
struct np{
int n1,n2;np(){n1=n2=0;}
np(int A,int B){n1=A;n2=B;}
np(int x){n1 = x % MOD1;n2 = x % MOD2;}
};
np operator + (np a,np b) {return np(MD(a.n1+b.n1,MOD1),MD(a.n2+b.n2,MOD2));}
np operator + (np a,int b) {return np(MD(a.n1+b,MOD1),MD(a.n2+b,MOD2));}
np operator - (np a,np b) {return np(MD(a.n1+MOD1-b.n1,MOD1),MD(a.n2+MOD2-b.n2,MOD2));}
np operator - (np a,int b) {return np(MD(a.n1+MOD1-b,MOD1),MD(a.n2+MOD2-b,MOD2));}
np operator * (np a,np b) {return np(a.n1*1ll*b.n1%MOD1,a.n2*1ll*b.n2%MOD2);}
np operator * (np a,int b) {return np(a.n1*1ll*b%MOD1,a.n2*1ll*b%MOD2);}
bool operator == (np a,np b) {return a.n1 == b.n1 && a.n2 == b.n2;}
bool operator != (np a,np b) {return a.n1 != b.n1 || a.n2 != b.n2;}
np pre[MAXN],suf[MAXN<<1],po[MAXN],invp[MAXN];
np query(int ad,int le) {
int ad2 = max(Ld-1,ad-le),le2 = ad-ad2;
np nm1 = pre[ad] - (pre[ad2]*po[le2]);
if(le2 < le) {
np nm2 = suf[ad+1] - (suf[ad+le-le2+1]*po[le-le2]);
nm1 = nm1 + (nm2 * po[le2]);
}return nm1;
}
bool CMP(int a,int b) {
int l = 1,r = n,mid;
while(l < r) {
mid = (l + r) >> 1;
if(query(a,mid) != query(b,mid)) r = mid;
else l = mid+1;
}
return getit(a,l) < getit(b,l);
}
int main() {
freopen("thoughts.in","r",stdin);
freopen("thoughts.out","w",stdout);
n = read();
for(int i = 1;i <= n;i ++) a[i] = read();
mi[n+1] = MOD;
for(int i = n;i > 0;i --) mi[i] = min(mi[i+1],a[i]);
pre[0] = suf[n+1] = np(0);po[0] = invp[0] = np(1);
np INVP = np(qkpow(MOD,MOD1-2,MOD1),qkpow(MOD,MOD2-2,MOD2));
for(int i = 1;i <= n;i ++) {
pre[i] = pre[i-1] * MOD + a[i];
po[i] = po[i-1] * MOD;
invp[i] = invp[i-1] * INVP;
}
for(int i = n;i > 0;i --) {
suf[i] = suf[i+1] * MOD + a[i];
}
Ld = 0;
for(int i = 1;i <= n;i ++) {
if(a[i] > mi[i]) {
Ld = i; break;
}
}
if(Ld) {
le[Ld-1] = 0; cnb = 0;
for(int i = Ld;i <= n;i ++) {
if(a[i] == a[i-1]) le[i] = le[i-1]+1;
else le[i] = 1;
Push(it(i,a[i],le[i]));
}
int Pos = b[1].ad;
for(int i = 2;i <= cnb;i ++) {
if(CMP(b[i].ad,Pos)) {
Pos = b[i].ad;
}
}
for(int i = Ld,j = Pos;i < j;i ++,j --) {
swap(a[i],a[j]);
}
}
int ans = 0,pw = 1;
for(int i = n;i > 0;i --) {
MD(ans += a[i] *1ll* pw % MOD);
pw = pw * 114514ll % MOD;
}
printf("%d\n",ans);
return 0;
}