题目链接(上面是MenciOJ,下面是Hackerrank ):
https://www.hackerrank.com/contests/101hack43/challenges/k-inversion-permutations
其实还有一个bzoj2431也是一样的题目,但是数据范围更小,允许O(nk)的平方级算法。
题意:给出n,k,求逆序对数为k的1-n的排列的数量. n,k<=10^5.
题解:先来考虑一下bzoj2431的做法,这个做法是O(nk)的,
考虑一下dp:dp[i][j]表示i个数形成排列,逆序对数为j的方案数,
dp转移为:dp[i][j]=sigma(dp[i-1][j-k],k=0..min(j,i-1)).
前缀和优化一下即可得到O(nk)的解法。
dp方程解释:对于前i个数,我们不关心它们的具体大小,只关心相对大小,
因为相对大小就可以确定一个序列的逆序对数量,而指数级别的状态数被减少到了平方级别,
这样就可以把它们离散到一个[1,i]的区间上,通过枚举下一个位置选择的数在这个[1,i]排列中的位置(也就是与前(i-1)个数的相对大小)进行转移。
但是这个方法的瓶颈在于其还是关心了数字之间的大小关系,而且没有考虑通过计数问题常用的组合数等方式去优化。
考虑上述dp方程的实际意义,其实就是求数列{ai}的组数,满足sigma(ai)=k且0<=ai<=i-1.
于是根据这个思路考虑优化的做法:
因为如果没有0<=ai<=i-1的限制答案就是一个组合数,
考虑利用容斥原理去干掉这个限制,
如果我们钦定至少有m个ai满足ai>=i,
设这些i的和为s, 则方案数为满足n个非负数和为(k-s)的方案数,
即为C(k-s+n-1,n-1),预处理后可以O(1)计算。
而由容斥原理可以发现,它对答案的贡献的系数是(-1)^m.
因为如果有p个不满足条件(p>0),
它会在sigma(C(p,2k))个位置做出1的贡献,
在sigma(C(p,2k+1))个位置做出-1的贡献,
在p>0时,其可以两两抵消,最终的贡献是0,
而在p=0时,其贡献为C(0,0)=1.
我们可以dp出选出m个i,和为s的方案数,然后组合数求解,
dp过程中,因为有dp[i][j]=dp[i][j-i]+dp[i-1][j-i]-dp[i-1][j-n-1],所以每个dp值都可以O(1)计算,
dp方程的解释:i个数,和为j-i的情况下,每个数加1可以形成i个数,和为j的状况,
但是,如果原来选出的数字中有n,是不合法的,所以减去,
同时,对于i个数和为j,有一个1被选择的情况没有考虑到,所以加上。
第一维的状态数是sqrt(k)的,所以可以O(k^1.5)求出最终的答案。
Code:
#include <bits/stdc++.h>
#define ll long long
#define mod 1000000007ll
using namespace std;
ll inv[200005],fac[200005];
ll dp[100005],prv[100005];
ll c[100005],p[100005];
inline void pre()
{int i;
fac[0]=1ll;fac[1]=1ll;
for (i=2;i<=200000;i++)
{fac[i]=fac[i-1]*((ll)(i))%mod;}
inv[0]=1ll;inv[1]=1ll;
for (i=2;i<=200000;i++)
{inv[i]=(mod-mod/i)*inv[mod%i]%mod;}
for (i=2;i<=200000;i++)
{inv[i]=inv[i]*inv[i-1]%mod;}
}
inline ll cal(int n,int m)
{return fac[n]*inv[m]%mod*inv[n-m]%mod;}
int main (){
int n,k,i,j;
pre();
scanf ("%d%d",&n,&k);
for (i=0;i<=k;i++)
{c[i]=cal(i+n-1,n-1);p[i]=0;}
dp[0]=1;p[k]=1;
for (i=1;i<=450;i++)
{int tag=i&1;
for (j=0;j<=k;j++) prv[j]=dp[j],dp[j]=0;
for (j=i*(i+1)/2;j<=k;j++)
{dp[j]=dp[j-i]+prv[j-i];
if (j-n-1>=0) {dp[j]-=prv[j-n-1];}
while (dp[j]>=mod) {dp[j]-=mod;}
if (dp[j]<0) {dp[j]+=mod;}
if (!tag)
{p[k-j]+=dp[j];}
else
{p[k-j]-=dp[j];}
}
}
ll ans=0;
for (i=0;i<=k;i++)
{p[i]%=mod;
if (p[i]<0) {p[i]+=mod;}
ans+=(c[i]*p[i])%mod;
if (ans>=mod) {ans-=mod;}
}
printf ("%lld\n",ans);
return 0;
}