LYK loves girls
题解
首先看到这道题,应该是很容易想到去dp的。
但由于要在一个环上处理掉循环同构的情况,所以我们还得对最后的到的dp处理一下。
先破环为链,由于链的前端与后端是连在一起的,需要记录一下这两块的W长度。
记
d
p
i
,
j
,
k
dp_{i,j,k}
dpi,j,k表示总共
i
i
i个位置,前端有
j
j
j个W,后端有
k
k
k个W的情况数。
方程式是很好想的
d
p
i
,
j
,
k
=
d
p
i
,
j
,
k
−
1
,
d
p
i
,
j
,
0
=
∑
k
=
0
K
d
p
i
,
j
,
k
dp_{i,j,k}=dp_{i,j,k-1},dp_{i,j,0}=\sum_{k=0}^{K}dp_{i,j,k}
dpi,j,k=dpi,j,k−1,dpi,j,0=∑k=0Kdpi,j,k
对于每个
j
+
k
j+k
j+k小于
K
K
K的dp值都是需要记入答案的。
但每个字符串的循环同构在我们当前记录的答案中出现次数都为
n
n
n吗?显然不是的。
由于我们的答案记录的是所有不同的字符串数量,所以它的出现次数应为它的最小循环节长度。
所以对于每种字符串,我们要除去的应为它的最小循环节长度。
所以,我们要进一步将
d
p
i
,
j
,
k
dp_{i,j,k}
dpi,j,k所代表的字符串转化为最小循环节为
i
i
i的长度为
n
n
n的字符串。
上面操作可以用容斥解决。很明显,一个长度为
a
a
a的字符串如果合法的话,那它重复
p
p
p次得到的长度为
a
p
ap
ap的字符串也是合法的,我们需要从
a
p
ap
ap中减去所有这样的
a
a
a。
类似卷积的做法可以让时间复杂度做到
O
(
K
2
n
l
n
n
)
O\left(K^2nln\,n\right)
O(K2nlnn),有
60
p
t
s
60pts
60pts。
考虑优化。
很明显,我们没必要对于每个字符串都有一个单独针对的dp。
如果,我们记
f
i
f_{i}
fi表示长度为
i
i
i的末尾为M的合法字符串个数,因为循环同构是一定可以得到一种末尾为M的情况,我们其实是可以通过这些字符串变形出其它字符串。
如果我们要得到长度为
l
e
n
len
len的合法字符串长度,可以用
∑
i
=
0
K
(
i
+
1
)
f
l
e
n
−
i
−
1
\sum_{i=0}^{K}(i+1)f_{len-i-1}
∑i=0K(i+1)flen−i−1来得到。
因为
f
f
f并未不能表示不全为M的循环字符串长度,我们需要通过在后面填上
i
i
i个W与一个M,在将这些依次从队尾移动至对首来得到所有的循环同构结构。
容易发现,任意一个长度为
l
e
n
len
len的合法字符串,在其中出现且仅出现一次。
顺便还可以在用前缀和优化一下
f
f
f的转移。
这样,就可以优化掉一个
K
K
K了。
不过好像并没有什么*用,还是会T。
究其根本原因,还是因为我们对于每一个
n
n
n的约数,都会存在一个容斥一样的家伙。
由于循环节长度不同,它的出现次数达不到
n
n
n,所以必定会产生一个容斥。
让他的出现次数达到
n
n
n,这样就可以一刀切了。
于是,我们对于每一个长度为
p
i
p_{i}
pi的,需要出现
n
p
i
\frac{n}{p_{i}}
pin次,其中
n
n
n是
p
i
p_{i}
pi的倍数。
我们忽然发现,这个数恰好是
n
n
n以内
p
i
p_{i}
pi的倍数的个数。
所以,我们可以对于每一个
i
i
i,直接加上
d
p
(
n
,
i
)
dp_{(n,i)}
dp(n,i)次。此时的
d
p
(
n
,
i
)
dp_{(n,i)}
dp(n,i)是不考虑最小循环节是否为
(
n
,
i
)
(n,i)
(n,i)的。
由于
p
i
p_{i}
pi是
n
n
n的约数,所以,所有的
(
n
,
i
)
(n,i)
(n,i)中刚好有
n
p
i
\frac{n}{p_{i}}
pin个包含
p
i
p_{i}
pi。
于是,我们可以直接将所有的
d
p
(
n
,
i
)
dp_{(n,i)}
dp(n,i)相加即可。
这些
(
n
,
i
)
(n,i)
(n,i)是有许多重复的,重复的我们可以记忆化,没必要在求了。
所以最后的总时间复杂度就是
O
(
n
+
D
(
n
)
K
)
O\left(n+D(n)K\right)
O(n+D(n)K)了,其中
D
(
n
)
D(n)
D(n)表示
n
n
n的约数个数。
好像说用Bunside引理也可以解释,不过我没看懂它的解释,就自己yy了一个。
源码
为什么我考场上会把容斥打错
注意至少有一个男的,所以我们没必要单独处理所有女生。
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<vector>
#include<queue>
#include<set>
using namespace std;
#define MAXN 100005
#define lowbit(x) (x&-x)
#define reg register
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int,int> pii;
const int mo=1e9+7;
const int INF=0x7f7f7f7f;
const double PI=acos(-1.0);
template<typename _T>
_T Fabs(_T x){return x<0?-x:x;}
template<typename _T>
void read(_T &x){
_T f=1;x=0;char s=getchar();
while(s>'9'||s<'0'){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
int n,K,ans,f[MAXN],inv[MAXN],sum[MAXN],dp[MAXN];
bool vis[MAXN];
int gcd(int a,int b){return !b?a:gcd(b,a%b);}
int add(int x,int y){return x+y<mo?x+y:x+y-mo;}
void init(){inv[1]=1;for(int i=2;i<=n;i++)inv[i]=1ll*(mo-mo/i)*inv[mo%i]%mo;}
signed main(){
freopen("girls.in","r",stdin);
freopen("girls.out","w",stdout);
read(n);read(K);K=min(n,K);init();
f[0]=f[1]=sum[0]=1;sum[1]=2;
for(int i=2;i<=n;i++)
f[i]=(i<K+2)?sum[i-1]:add(sum[i-1],mo-sum[i-K-2]),
sum[i]=add(sum[i-1],f[i]);
for(int i=1;i<=n;i++){
int x=gcd(i,n);
if(!dp[x])
for(int j=0;j<=min(x,K);j++)
dp[x]=add(dp[x],1ll*(j+1)*f[x-j-1]%mo);
ans=add(ans,dp[x]);
}
int tmp=1ll*ans*inv[n]%mo;
printf("%d\n",tmp);
return 0;
}