根据期望的线性性,我们只需对于每两个位置 (i,j) ( i , j ) 计算出其相对位置改变的概率,并根据 ai a i 和 aj a j 的大小关系统计贡献即可。
于是我们不难得到一个 O(n2k) O ( n 2 k ) 的DP。设 fi,j,k f i , j , k 表示当前 (i,j) ( i , j ) 两个位置再进行 k k 步交换操作使得的概率(边界就是 fi,j,0=[i>j] f i , j , 0 = [ i > j ] ),显然可以从 fi,j,k−1,fj,i,k−1,∑t≠jfi,t,k−1,∑t≠ift,j,k−1 f i , j , k − 1 , f j , i , k − 1 , ∑ t ≠ j f i , t , k − 1 , ∑ t ≠ i f t , j , k − 1 乘上其对应的概率转移过来,后两个和式先处理出来就可以 O(1) O ( 1 ) 转移了。
考虑优化这个DP。加入我们把两个位置
(i,j)
(
i
,
j
)
看成一个去除了直线
x=y
x
=
y
二维平面上的一个点,那么问题就转化成对于
x=y
x
=
y
上方的每一个点,求出
x=y
x
=
y
下方的所有点转移
k
k
步之后能到达它的概率之和,其中转移分别是:
1. 不动,。
2. 移动到当前行的另一个点,
P=n−2
P
=
n
−
2
。
3. 移动到当前列的另一个点,
P=n−2
P
=
n
−
2
。
4. 移动到关于
x=y
x
=
y
的对称点,
P=1
P
=
1
。
其中
z=n(n−1)2
z
=
n
(
n
−
1
)
2
。通过观察发现,平面上所有点经过若干次操作走到给定点
(i,j)
(
i
,
j
)
的概率只有本质不同的
5
5
种。具体如下:
其中是给定点
(i,j)
(
i
,
j
)
,
2
2
为其对称点,为与
(i,j)
(
i
,
j
)
同行同列的其它点,
4
4
为与同行同列的其它点,
0
0
为除掉以上四种剩下的点。注意到之间其实是没有交的,因为直线
x=y
x
=
y
已经被除去了。
那么这
5
5
种情况相互之间的转移矩阵就是:
矩阵快速幂搞出所有系数。有了系数我们就可以 n2 n 2 枚举 (i,j) ( i , j ) ,统计在 x=y x = y 下方每种点的个数即可。但我们观察到: 0 0 有个; 1 1 有个; 2 2 有个; 3 3 有个; 4 4 有个。都只和 j−i j − i 有关,于是我们可以用树状数组分别统计出顺序对/逆序对的对数和 ∑(j−i) ∑ ( j − i ) ,就可以一起算了。总复杂度 O(nlogn+125logk) O ( n log n + 125 log k ) 。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define N 500010
#define up(x,y) (x=(x+(y))%mod)
using namespace std;
const int mod=998244353;
int n,K,a[N];
ll z,iz,ans,cnt[5];
int read()
{
int x=0,f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=-1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*f;
}
ll ksm(ll a,ll b)
{
ll r=1;
for(;b;b>>=1,a=a*a%mod)
if(b&1) r=r*a%mod;
return r;
}
struct matrix
{
ll a[5][5];
matrix(){memset(a,0,sizeof(a));}
matrix operator *(matrix b)
{
matrix re;
for(int i=0;i<5;i++)
for(int j=0;j<5;j++)
for(int k=0;k<5;k++)
up(re.a[i][j],a[i][k]*b.a[k][j]);
return re;
}
void init()
{
a[0][0]=z-4; a[0][1]=0; a[0][2]=0; a[0][3]=2; a[0][4]=2;
a[1][0]=0; a[1][1]=z-2*n+3;a[1][2]=1; a[1][3]=2*n-4; a[1][4]=0;
a[2][0]=0; a[2][1]=1; a[2][2]=z-2*n+3;a[2][3]=0; a[2][4]=2*n-4;
a[3][0]=n-3; a[3][1]=1; a[3][2]=0; a[3][3]=z-n; a[3][4]=2;
a[4][0]=n-3; a[4][1]=0; a[4][2]=1; a[4][3]=2; a[4][4]=z-n;
}
}T;
matrix matksm(matrix a,ll b)
{
matrix r;
for(int i=0;i<5;i++)
r.a[i][i]=1;
for(;b;b>>=1,a=a*a)
if(b&1) r=r*a;
return r;
}
struct bit
{
int c[N];
void add(int x,ll d)
{
for(;x<=n;x+=(x&-x))
up(c[x],d);
}
ll qry(int x)
{
ll r=0;
for(;x;x-=(x&-x))
up(r,c[x]);
return r;
}
}t1,t2;
ll cal(ll num,ll sum)
{
ll re=0;
cnt[1]=0;cnt[2]=num;cnt[3]=(num*(n-1)-sum+mod)%mod;cnt[4]=(num*(n-3)+sum)%mod;
cnt[0]=((z*num-cnt[1]-cnt[2]-cnt[3]-cnt[4])%mod+mod)%mod;
for(int k=0;k<5;k++)
up(re,T.a[k][1]*cnt[k]);
return re;
}
int main()
{
n=read();K=read();
for(int i=1;i<=n;i++)
a[i]=read();
z=((ll)n*(n-1)/2)%mod;
iz=ksm(z,mod-2);
T.init();
T=matksm(T,K);
ll cnt[5],num=0,sum=0,P=0;
for(int i=1;i<=n;i++)
{
ll tmp=t1.qry(a[i]);
up(num,tmp);
up(sum,tmp*i-t2.qry(a[i])%mod+mod);
t1.add(a[i],1);
t2.add(a[i],i);
}
up(ans,cal(num,sum));
num=z-num;sum=-sum;
for(int i=1;i<n;i++)
up(sum,(ll)i*(n-i));
up(ans,ksm(z,K)*num%mod-cal(num,sum)+mod);
printf("%lld",ans);
return 0;
}