AtCoder Grand Contest 005 D - ~K Perm Counting
https://atcoder.jp/contests/agc005/tasks/agc005_d
首先想到如何去容斥,假设任意钦定i个非法点的选择方案是dp[i]
那么 ans= ∑ i = 0 n ( − 1 ) n ∗ d p [ i ] ∗ f a c [ i ] \sum_{i=0}^{n}(-1)^n*dp[i]*fac[i] ∑i=0n(−1)n∗dp[i]∗fac[i]
然后我们考虑如何去求出这个dp[i]
我们可以通过列出二分图来分析,第i个点非法的意思也就是他选了i-k或者选了i+k
我们就会发现互相影响的不能选的其实是一条链,从i,i+k,i+2k…这样一直延续下去
于是我们可以把每条链单独拿出来,f[i,j,0]为链的前i条边选了j条边且第i条边没选的方案数,1为选了的
然后这条链最后选了i条的方案数就是f[len-1,i,0]+f[len-1,i,1]
然后上一条链得到的方案数是dp[(cnt&1)^1] [0-n] ,那么新增一条链,就枚举f [len-1,0-len-1]去更新整个dp数组
由于,由于$\sum {len} $ =2*n,那么实际复杂度就是 O ( n 2 ) O(n^2) O(n2)的
写了100+行某牛逼网友博客27行就写完了。。。等下去学习一下
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxl=2010;
const int mod=924844033;
int n,m,k,cnt,tot,cas;ll ans;
int a[maxl];
ll dp[2][maxl],fac[maxl];
ll f[maxl][maxl][2];
bool vis[maxl*2];
char s[maxl];
inline void prework()
{
scanf("%d%d",&n,&k);
fac[0]=1;
for(int i=1;i<=n;i++)
fac[i]=fac[i-1]*i%mod;
for(int i=1;i<=2*n;i++)
vis[i]=false;
}
inline void add(ll &x,ll y){x+=y;if(x>=mod) x-=mod;}
inline int dfs(int i,int len)
{
vis[i]=true;
if(i<=n)
{
if(i+k>n || vis[i+k+n])
return len;
else
return dfs(i+k+n,len+1);
}
else
{
if(i-n+k>n || vis[i-n+k])
return len;
else
return dfs(i-n+k,len+1);
}
}
inline void calc(int len)
{
for(int i=1;i<=len-1;i++)
for(int j=0;j<=i;j++)
f[i][j][0]=f[i][j][1]=0;
f[0][0][0]=1;f[0][0][1]=0;
for(int i=0;i<=len-2;i++)
for(int j=0;j<=i;j++)
{
add(f[i+1][j][0],f[i][j][0]);
add(f[i+1][j+1][1],f[i][j][0]);
add(f[i+1][j][0],f[i][j][1]);
}
}
inline void mainwork()
{
dp[0][0]=1;cnt=0;
for(int i=1;i<=n;i++)
if(!vis[i])
{
int len=0;
if(i-k>0 && !vis[i-k+n])
{
vis[i-k+n]=true;
len=dfs(i,2);
}
else
len=dfs(i,1);
calc(len);++cnt;int d=cnt&1;
for(int j=0;j<=n;j++)
dp[d][j]=0;
for(int j=0;j<=len-1;j++)
for(int l=n;l-j>=0;l--)
add(dp[d][l],dp[d^1][l-j]*(f[len-1][j][0]+f[len-1][j][1])%mod);
}
ans=0;
for(int i=0;i<=n;i++)
if(i&1)
ans=(ans-fac[n-i]*dp[cnt&1][i]%mod+mod)%mod;
else
ans=(ans+fac[n-i]*dp[cnt&1][i]%mod)%mod;
}
inline void print()
{
printf("%lld\n",ans);
}
int main()
{
int t=1;
//scanf("%d",&t);
for(cas=1;cas<=t;cas++)
{
prework();
mainwork();
print();
}
return 0;
}