D. Yet Another Inversions Problem
题意
给定正整数
n
n
n 和
k
k
k,并分别给出一个长度为
n
n
n 的奇排列
p
p
p 和 一个长度为
k
k
k 的
0
0
0排列
q
q
q
按照题中给出的方式构造出数组
a
a
a,求出
a
a
a 中的逆序对数量
思路
考虑将 a a a 分解成 n n n 个长度为 k k k 的子数组,那么可以发现这些子数组内部的逆序对数量等于 q q q 中原先的逆序对数量,因为 p i p_i pi 固定,只有 q j q_j qj 在变化。我们就可以先用树状数组求出 q q q 中原来的逆序对数量,乘上 n n n 先加到答案中。
那么现在考虑前后两个子数组的连接,会生成多少个逆序对:
由于在前面已经考虑了每个长度为
k
k
k 的子数组内部的逆序对,所以这里可以将前后两个要连接的子数组先排列好:
[
x
⋅
2
0
,
x
⋅
2
1
,
x
⋅
2
2
,
.
.
.
,
x
⋅
2
k
−
1
]
[x \cdot 2^0,x \cdot 2^1,x \cdot 2^2,...,x \cdot 2^{k-1}]
[x⋅20,x⋅21,x⋅22,...,x⋅2k−1],
[
y
⋅
2
0
,
y
⋅
2
1
,
y
⋅
2
2
,
.
.
.
,
y
⋅
2
k
−
1
]
[y \cdot 2^0,y \cdot 2^1,y \cdot 2^2,...,y \cdot 2^{k-1}]
[y⋅20,y⋅21,y⋅22,...,y⋅2k−1],类似于归并排序求解逆序对的方法。
现在连接这两个子数组,假设
x
<
y
x < y
x<y,假设
y
⋅
2
0
y \cdot 2^0
y⋅20 前面要放的子数组
1
1
1 的元素有
z
z
z 个,那么
x
⋅
2
z
<
y
x \cdot 2^z < y
x⋅2z<y,也就是说
z
z
z 是满足这个关系式的最大的整数(关系式无法取到等号,因为
x
≠
y
x \neq y
x=y 且
x
、
y
x、y
x、y 均为奇数)。那么有:
x
⋅
2
z
<
y
<
x
⋅
2
z
+
1
x \cdot 2^z < y < x \cdot 2^{z+1}
x⋅2z<y<x⋅2z+1,也就是说:在前面放了
z
+
1
z+1
z+1 个数组
1
1
1 的元素后,后面数组
1
1
1 和数组
2
2
2 的元素一定交替出现,并且最后一定是以
z
+
1
z+1
z+1 个数组
2
2
2 的元素结尾。例如官方题解的解释:
根据上述关系式,可以求出:
z
=
log
2
y
x
z = \log_2 \dfrac{y}{x}
z=log2xy,那么对于一个当前的
y
=
p
i
y = p_i
y=pi,只要前面
[
1
,
i
−
1
]
[1,i-1]
[1,i−1] 的
x
x
x 与
y
y
y 的大小关系根据上式求出的
z
z
z 一样,那么这些
x
x
x 所属的子数组与当前
y
y
y 所属的子数组连接时,排列方式是一样的。又因为
y
≤
2
⋅
n
−
1
y \leq 2\cdot n - 1
y≤2⋅n−1,
x
≥
1
x \geq 1
x≥1,
故
y
x
≤
2
⋅
n
⇒
log
2
y
x
≤
log
2
(
2
⋅
n
)
\dfrac{y}{x} \leq 2\cdot n \Rightarrow \log_2 \dfrac{y}{x} \leq \log_2 (2\cdot n)
xy≤2⋅n⇒log2xy≤log2(2⋅n),也就是说:按照这样子的方式来划分连接方式的话,最多只有
log
2
(
2
⋅
n
)
\log_2 (2 \cdot n)
log2(2⋅n) 种连接方式!
进一步观察不难发现:同一种连接方式所产生的逆序对数量是一样的,并且符合某种等差数列的变化方式。
- T i p s Tips Tips:我们可以预处理首项和公差为 1 1 1 的前缀和数组 s u m sum sum,加速计算。
依旧是假设
x
<
y
x < y
x<y,取
z
=
⌊
log
2
y
x
⌋
z = \lfloor \log_2 \dfrac{y}{x} \rfloor
z=⌊log2xy⌋,
y
⋅
2
0
y \cdot 2^0
y⋅20 前面已经出现了
z
+
1
z + 1
z+1 个
x
x
x 子数组里的元素,故其后面只剩下
k
−
z
−
1
k - z - 1
k−z−1 个
x
x
x 子数组里的元素了,每往后一个
y
y
y 的元素,其产生的逆序对数量减一,最后
z
+
1
z + 1
z+1 个
y
y
y 元素不产生逆序对。总的逆序对数量是:
(
k
−
z
−
1
)
+
(
k
−
z
−
2
)
+
(
k
−
z
−
3
)
+
.
.
.
+
1
=
s
u
m
[
k
−
z
−
1
]
(k-z-1) + (k-z-2) + (k-z-3) + ... + 1 = sum[k-z-1]
(k−z−1)+(k−z−2)+(k−z−3)+...+1=sum[k−z−1]。
但是当
k
≤
z
+
1
k \leq z + 1
k≤z+1 时,这种
x
<
y
x < y
x<y 的情况由于
x
x
x 子数组的元素全部在
y
y
y 的前面,因此产生的逆序对数量是
0
0
0。
对于
x
>
y
x > y
x>y 的情况也是类似:不过把
x
x
x 和
y
y
y 的位置互换一下,现在现在前面放
z
+
1
z + 1
z+1 个
y
y
y 子数组的元素,然后
x
x
x、
y
y
y 交替放,最后
z
+
1
z + 1
z+1 个
x
x
x 子数组的元素。最前面
z
+
1
z+1
z+1 个
y
y
y 元素产生的逆序对数量是:
(
z
+
1
)
⋅
k
(z+1) \cdot k
(z+1)⋅k,后面的情况与前面类似,不断递减,但是最后可能无法减到
1
1
1,因为最后
z
+
1
z+1
z+1 个
x
x
x 元素连续放在最后,所以最后只能减到
z
+
1
z+1
z+1。总的逆序对数量就是:
z
⋅
k
+
k
+
(
k
−
1
)
+
(
k
−
2
)
+
.
.
.
+
(
z
+
1
)
=
z
⋅
k
+
s
u
m
[
k
]
−
s
u
m
[
z
]
z \cdot k + k + (k-1) + (k-2) +...+ (z+1) = z \cdot k + sum[k] - sum[z]
z⋅k+k+(k−1)+(k−2)+...+(z+1)=z⋅k+sum[k]−sum[z]。
同样的道理,当
z
+
1
≥
k
z + 1 \geq k
z+1≥k 时,所有的
y
y
y 元素都放在最前面的连续
k
k
k 个,后面连续放
x
x
x 数组的元素,这种情况产生的逆序对数量是:
k
2
k^2
k2
讲到这里就已经大概知道怎么写了,用线段树单点修改,区间查询来维护某个区间出现了多少个
p
i
p_i
pi ,对于当前的
y
=
p
i
y = p_i
y=pi ,我们先枚举
z
∈
[
0
,
log
2
2
n
]
z \in [0,\log_2 2n]
z∈[0,log22n],然后分别对情况
1
1
1 和情况
2
2
2 的所有符合条件的
x
x
x 分别计算:
对于当前的
z
z
z ,所有小于
y
y
y 的
x
∈
[
y
2
z
+
1
,
y
2
z
]
x \in [\dfrac{y}{2^{z+1}},\dfrac{y}{2^z}]
x∈[2z+1y,2zy],所有大于
y
y
y 的
x
∈
[
2
z
⋅
y
,
2
z
+
1
⋅
y
]
x \in [2^z \cdot y,2^{z+1} \cdot y]
x∈[2z⋅y,2z+1⋅y]。由于
∀
p
i
为奇数
\forall p_i 为奇数
∀pi为奇数,所以这里的区间的边界可以忽略。利用线段树查询出这些区间的
x
x
x 数量,乘上相应的权值加到答案上即可。
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define lowbit(x) ((x) & -(x))
const int INF=0x3f3f3f3e;
const long long INFLL=0x3f3f3f3f3f3f4000LL;
typedef long long ll;
const int N = 200050;
ll cnt[N<<3];
int p[N], q[N];
const ll mod = 998244353;
ll sum[N];
ll LOG2[N<<1];
std::vector<ll> fen;
ll n,k;
ll P2[N];
void Fen_update(int x,int d){
while(x <= k){
fen[x] += d;
x += lowbit(x);
}
}
ll Fen_query(int x){
ll res = 0;
while(x > 0){
res += fen[x];
x -= lowbit(x);
}
return res;
}
void build(int p,int l,int r){
cnt[p] = 0;
if(l == r){
return;
}
int mid = l + r >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
}
void update(int p,int l,int r,int qp){
if(l == r){
++cnt[p];
return;
}
int mid = l + r >> 1;
if(qp <= mid) update(p<<1,l,mid,qp);
else update(p<<1|1,mid+1,r,qp);
cnt[p] = cnt[p<<1] + cnt[p<<1|1];
}
ll query(int p,int l,int r,int ql,int qr){
if(ql <= l && r <= qr) return cnt[p];
if(r < ql || qr < l) return 0;
int mid = l + r >> 1;
ll res = 0;
if(ql <= mid) res = query(p<<1,l,mid,ql,qr);
if(qr > mid) res += query(p<<1|1,mid+1,r,ql,qr);
return res;
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
fore(i,1,N) sum[i] = (sum[i-1] + i)%mod;
LOG2[0] = -1;
fore(i,1,N<<1) LOG2[i] = LOG2[i>>1] + 1;
P2[0] = 1;
fore(i,1,N) P2[i] = P2[i-1] * 2ll % mod;
int t;
std::cin>>t;
while(t--){
std::cin>>n>>k;
build(1,1,2*n);
fore(i,1,n+1) std::cin>>p[i];
fore(i,1,k+1) std::cin>>q[i];
ll ans = 0;
fen.assign(k+10,0);
fore(i,1,k+1){
ans = (ans + Fen_query(k) - Fen_query(q[i]+1) + mod)%mod;
Fen_update(q[i]+1,1);
}
ans = ans*n % mod;
update(1,1,2*n,p[1]);
fore(i,2,n+1){
fore(z,0,LOG2[2*n] + 1){
int y = p[i];
if(k - z - 1 > 0) ans = (ans + query(1,1,2*n,y/P2[z+1]+1,y/P2[z]) * sum[k - z - 1] %mod) % mod;
if(P2[z] * y >= 2*n) continue; //这里没有爆long long,但是传入query的时候爆int了
if(z + 1 <= k)ans = (ans + query(1,1,2*n,P2[z]*y,P2[z+1]*y) * (((k*z%mod + sum[k])%mod - sum[z] + mod)%mod)%mod) % mod;
else ans = (ans + query(1,1,2*n,P2[z]*y,P2[z+1]*y)*k%mod*k%mod)%mod;
}
update(1,1,2*n,p[i]);
}
std::cout<<ans<<endl;
}
return 0;
}