题意:
数据范围:n<=2e5
解法:
∑
i
=
1
n
∣
p
(
i
)
−
q
(
i
)
∣
=
∑
i
=
1
n
m
a
x
(
p
(
i
)
,
q
(
i
)
)
−
m
i
n
(
p
(
i
)
,
q
(
i
)
)
m
a
x
(
p
i
,
q
i
)
一
定
是
前
n
大
的
数
,
m
i
n
(
q
,
p
i
)
一
定
是
前
n
小
的
数
.
如
果
不
是
这
样
:
1.
某
个
m
a
x
(
p
i
,
q
i
)
和
m
i
n
(
p
i
,
q
i
)
都
是
前
n
大
,
那
么
会
导
致
某
个
m
a
x
(
p
j
,
q
j
)
和
m
i
n
(
p
j
,
q
j
)
都
是
前
n
小
,
这
样
的
匹
配
在
p
非
递
减
,
q
非
递
增
的
排
序
方
式
下
是
不
存
在
的
.
2.
某
个
m
a
x
(
p
i
,
q
i
)
和
m
i
n
(
p
i
,
q
i
)
都
是
前
n
小
,
那
么
会
导
致
某
个
m
a
x
(
p
j
,
q
j
)
和
m
i
n
(
p
j
,
q
j
)
都
是
前
n
大
.
这
样
的
匹
配
在
p
非
递
减
,
q
非
递
增
的
排
序
方
式
下
是
不
存
在
的
.
因
此
无
论
怎
么
排
列
,
f
函
数
的
值
=
前
n
大
−
前
n
小
.
总
排
列
数
为
C
(
2
n
,
n
)
,
乘
上
f
函
数
的
值
就
是
答
案
.
\sum_{i=1}^n|p(i)-q(i)|=\sum_{i=1}^nmax(p(i),q(i))-min(p(i),q(i))\\ max(pi,qi)一定是前n大的数,min(q,pi)一定是前n小的数.\\ 如果不是这样:\\ 1.某个max(pi,qi)和min(pi,qi)都是前n大,那么会导致某个max(pj,qj)和min(pj,qj)都是前n小,\\ 这样的匹配在p非递减,q非递增的排序方式下是不存在的.\\ 2.某个max(pi,qi)和min(pi,qi)都是前n小,那么会导致某个max(pj,qj)和min(pj,qj)都是前n大.\\ 这样的匹配在p非递减,q非递增的排序方式下是不存在的.\\ 因此无论怎么排列,f函数的值=前n大-前n小.\\ 总排列数为C(2n,n),乘上f函数的值就是答案.
i=1∑n∣p(i)−q(i)∣=i=1∑nmax(p(i),q(i))−min(p(i),q(i))max(pi,qi)一定是前n大的数,min(q,pi)一定是前n小的数.如果不是这样:1.某个max(pi,qi)和min(pi,qi)都是前n大,那么会导致某个max(pj,qj)和min(pj,qj)都是前n小,这样的匹配在p非递减,q非递增的排序方式下是不存在的.2.某个max(pi,qi)和min(pi,qi)都是前n小,那么会导致某个max(pj,qj)和min(pj,qj)都是前n大.这样的匹配在p非递减,q非递增的排序方式下是不存在的.因此无论怎么排列,f函数的值=前n大−前n小.总排列数为C(2n,n),乘上f函数的值就是答案.
解释一下为什么不存在:
假设上图中的绿点是两个非法匹配点,假设是两个前n大,
那么根据排序规则,p中的绿点一定在右边,q中的绿点一定在左边,
因为前n大只有n个,所以这两个点一定不会对应上,即不存在这样的匹配。
code:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxm=1e6+5;
const int mod=998244353;
int fac[maxm],inv[maxm];
int a[maxm];
int n;
int ppow(int a,int b,int mod){
int ans=1%mod;a%=mod;
for(;b;b>>=1,a=a*a%mod)if(b&1)ans=ans*a%mod;
return ans;
}
void init(){
fac[0]=1;
for(int i=1;i<maxm;i++)fac[i]=fac[i-1]*i%mod;
inv[maxm-1]=ppow(fac[maxm-1],mod-2,mod);
for(int i=maxm-2;i>=0;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int C(int n,int m){
if(m<0||m>n)return 0;
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
ios::sync_with_stdio(0);
init();
cin>>n;
for(int i=1;i<=n*2;i++)cin>>a[i];
sort(a+1,a+1+n*2);
int mis=0,mas=0;
for(int i=1;i<=n;i++){
mis=(mis+a[i])%mod;
}
for(int i=n+1;i<=n*2;i++){
mas=(mas+a[i])%mod;
}
int ans=C(n*2,n)*(mas-mis)%mod;
ans=(ans%mod+mod)%mod;
cout<<ans<<endl;
return 0;
}