题目:
给定一个长度为长度为
2
n
2n
2n 的数组
a
a
a ,现在将数组
a
a
a分裂成两个长度为
n
n
n 的数组
p
p
p 和数组
q
q
q,将数组
p
p
p 按从小到大排序得到
x
x
x,将数组
q
q
q 按从大到小排序,得到数组
y
y
y,定义
f
(
p
,
q
)
=
∑
i
=
1
n
∣
x
i
−
y
i
∣
f(p,q)=\sum_{i=1}^n |x_i-y_i|
f(p,q)=∑i=1n∣xi−yi∣,问所有对
a
a
a的合法划分的
f
(
p
,
q
)
f(p,q)
f(p,q)的和是多少。
(
1
≤
n
≤
150000
,
1
≤
a
i
≤
1
0
9
)
(1 \le n \le 150000,1 \le a_i \le 10^9)
(1≤n≤150000,1≤ai≤109)
题解:
由于划分之后需要排序,所以在划分前进行排序不影响结果,我们令
b
=
b=
b=从小到大排序后的
a
a
a。
结论:对于所有的合法划分,
f
(
p
,
q
)
f(p,q)
f(p,q) 的值都是相同的。
证明:令集合
L
=
{
x
∣
x
∈
b
1...
n
}
L=\{ x|x \in b_{1...n}\}
L={x∣x∈b1...n},
R
=
{
x
∣
x
∈
b
n
+
1...2
n
}
R=\{x|x \in b_{n+1...2n} \}
R={x∣x∈bn+1...2n}。现在我们来证明对于任意的
∣
x
i
−
y
i
∣
|x_i-y_i|
∣xi−yi∣都对应于一个
R
R
R中的元素减去一个
L
L
L中的元素。利用反证法,如果不满足,不妨设
∃
x
i
,
y
i
∈
L
,
x
i
<
y
i
\exist x_i,y_i \in L,x_i <y_i
∃xi,yi∈L,xi<yi,那么
y
i
>
x
1...
i
y_i>x_{1...i}
yi>x1...i且
y
i
>
y
i
+
1...
n
y_i>y_{i+1...n}
yi>yi+1...n,那么
y
i
y_i
yi至少比
n
n
n个数大,则
y
i
y_i
yi不可能在
L
L
L中,矛盾,其他情况也可以类似证明,所以对于任意的
∣
x
i
−
y
i
∣
|x_i-y_i|
∣xi−yi∣都对应于一个
R
R
R中的元素减去一个
L
L
L中的元素,所以上述结论成立。
所有的合法划分的个数为
C
2
n
n
C_{2n}^n
C2nn,所以最后的答案为
C
2
n
n
×
∑
i
=
1
n
b
i
+
n
−
b
i
C_{2n}^n \times \sum_{i=1}^n b_{i+n}-b_i
C2nn×i=1∑nbi+n−bi
复杂度: O ( n l o g n ) O(nlogn) O(nlogn)
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<string>
#include<bitset>
#include<sstream>
#include<ctime>
//#include<chrono>
//#include<random>
//#include<unordered_map>
using namespace std;
#define ll long long
#define ls o<<1
#define rs o<<1|1
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define sz(x) (int)(x).size()
#define all(x) (x).begin(),(x).end()
const double pi=acos(-1.0);
const double eps=1e-6;
const int mod=998244353;
const int INF=0x3f3f3f3f;
const int maxn=3e5+5;
int n;
int a[maxn];
ll qpow(ll a,ll p=mod-2){
ll res=1;
while(p){
if(p&1)res=res*a%mod;
a=a*a%mod;
p>>=1;
}
return res;
}
int main(void){
scanf("%d",&n);
for(int i=1;i<=2*n;i++){
scanf("%d",&a[i]);
}
sort(a+1,a+2*n+1);
ll sum=0;
for(int i=1;i<=n;i++){
sum-=a[i];
}
for(int i=n+1;i<=2*n;i++){
sum+=a[i];
}
sum%=mod;
ll num=1;
for(int i=n+1;i<=2*n;i++){
num=num*i%mod;
}
ll tmp=1;
for(int i=1;i<=n;i++){
tmp=tmp*i%mod;
}
tmp=qpow(tmp);
num=num*tmp%mod;
ll ans=sum*num%mod;
printf("%lld\n",ans);
return 0;
}