题目大意:
给出一个长度为n的全排列
p
p
p,
设
m
i
d
l
,
r
mid_{l,r}
midl,r为全排列中区间
[
l
,
r
]
[l,r]
[l,r]的中位数,
求
2
∗
∑
l
=
1
n
∑
r
=
l
n
s
e
e
d
(
l
−
1
)
∗
n
+
r
m
i
d
l
,
r
2*\sum_{l=1}^{n}\sum_{r=l}^{n}seed^{(l-1)*n+r}mid_{l,r}
2∗∑l=1n∑r=lnseed(l−1)∗n+rmidl,r。
答案对
1
e
9
+
7
1e9+7
1e9+7取模。
n
≤
1
e
4
,
s
e
e
d
<
1
e
9
+
7
n≤1e4,seed<1e9+7
n≤1e4,seed<1e9+7
分析:
提前处理出
s
e
e
d
n
seed^n
seedn,设为
s
e
e
d
n
u
m
seednum
seednum
枚举右端点
r
r
r,
r
r
r为排列
p
p
p中某个位置
对
[
1
,
r
]
[1,r]
[1,r]建立链表S,不是
P
[
l
.
.
r
]
P[l..r]
P[l..r],是数
[
l
,
r
]
[l,r]
[l,r],
S
r
S_{r}
Sr可以由
S
r
+
1
S_{r+1}
Sr+1删掉数
p
r
+
1
p_{r+1}
pr+1
令
T
T
T=
S
r
S_{r}
Sr
然后从左到右的枚举左端点
l
l
l,
每次用链表维护一下中位数就可以,因为可以数列长度为偶数,那么我们令
m
i
d
1
,
m
i
d
2
mid_1,mid_2
mid1,mid2表示2个中位数分别的位置,如果数列长度为奇数,则
m
i
d
1
=
m
i
d
2
mid1=mid2
mid1=mid2。
每次
l
+
1
l+1
l+1都将
T
T
T中的
p
l
p_{l}
pl删掉,
s
e
e
d
(
l
−
1
)
∗
n
+
r
seed^{(l-1)*n+r}
seed(l−1)∗n+r可以通过
s
e
e
d
n
seed^n
seedn的预处理在每次
l
+
1
l+1
l+1时O(1)计算,每次维护链表都是
O
(
1
)
O(1)
O(1)的
然后这题就搞完了
时间复杂度是
O
(
n
2
)
O(n^2)
O(n2)
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstring>
#include <algorithm>
#define N 10005
using namespace std;
typedef long long ll;
const int mo = 1e9 + 7;
struct Node {
int pre[N], nxt[N];
void del(int x)
{
nxt[pre[x]] = nxt[x];
pre[nxt[x]] = pre[x];
}
}S, T;
int p[N], seednum, seed, ans, n;
int power(int x, int y)
{
int rp = 1;
for (; y; y >>= 1)
{
if (y & 1) rp = (ll)rp * x % mo;
x = (ll)x * x % mo;
}
return rp;
}
int main()
{
scanf("%d %d", &n, &seed);
for (int i = 1; i <= n; i++) scanf("%d", &p[i]);
seednum = power(seed, n);
for(int i = 1; i <= n + 1; i++) S.pre[i] = i - 1, S.nxt[i - 1] = i;
int pos1 = (n + 1) / 2, pos2 = (n + 2) / 2;
for (int i = n; i >= 1; i--)
{
int posA = pos1, posB = pos2;
T = S;
int now = power(seed, i);
for (int j = 1; j <= i; j++)
{
ans = (ans + (ll)now * (posA + posB) % mo) % mo;
if (posA == posB)
{
if (p[j] <= posB) posB = T.nxt[posB];
if (p[j] >= posA) posA = T.pre[posA];
}
else
{
if (p[j] <= posA) posA = T.nxt[posA];
if (p[j] >= posB) posB = T.pre[posB];
}
T.del(p[j]);
now = (ll)now * seednum % mo;
}
if (pos1 == pos2)
{
if (p[i] <= pos2) pos2 = S.nxt[pos2];
if (p[i] >= pos1) pos1 = S.pre[pos1];
}
else
{
if (p[i] <= pos1) pos1 = S.nxt[pos1];
if (p[i] >= pos2) pos2 = S.pre[pos2];
}
S.del(p[i]);
}
printf("%d\n", ans);
return 0;
}