题目链接
https://www.luogu.org/problemnew/show/P4916
题解
将项链用序列表示,
1
1
1代表黑色,
0
0
0代表白色,对于一个合法序列,它必定能表示成
d
d
d个循环节,每个循环节
n
d
\frac{n}{d}
dn个珠子,其中
m
d
\frac{m}{d}
dm个黑色珠子。假设循环节长度为
n
n
n的合法序列方案数为
f
(
n
)
f(n)
f(n)(不考虑旋转后相同的情况),容易发现答案就是
∑
d
∣
n
1
d
f
(
d
)
\sum_{d|n}\frac{1}{d}f(d)
d∣n∑d1f(d)
定义
g
(
n
)
=
∑
d
∣
n
f
(
d
)
g(n)=\sum_{d|n}f(d)
g(n)=d∣n∑f(d)
容易发现
f
(
n
)
=
∑
d
∣
n
μ
(
n
d
)
g
(
d
)
f(n)=\sum_{d|n}\mu(\frac{n}{d})g(d)
f(n)=d∣n∑μ(dn)g(d)
考虑如何求
g
(
d
)
g(d)
g(d)。显然可以对每个循环节分开考虑,对于长度为
a
a
a,有
b
b
b颗黑珠子的循环节,考虑被
a
−
b
a-b
a−b颗白珠子划分成的黑珠子每一段的数量,
g
(
a
)
g(a)
g(a)就是下面方程的整数解的方案数
∑
i
=
0
a
−
b
x
i
=
b
(
∀
i
∈
[
0
,
a
−
b
]
,
0
≤
x
i
≤
k
,
x
0
+
x
a
−
b
≤
k
)
\sum_{i=0}^{a-b}x_i=b(\forall i\in [0,a-b],0\leq x_i\leq k,x_0+x_{a-b}\leq k)
i=0∑a−bxi=b(∀i∈[0,a−b],0≤xi≤k,x0+xa−b≤k)
容易发现
g
(
a
)
g(a)
g(a)就是下面式子在
x
b
x^b
xb项的系数
(
∑
i
=
0
k
x
i
)
a
−
b
−
1
(
(
∑
i
=
0
k
x
i
)
2
 
m
o
d
 
x
k
+
1
)
(\sum_{i=0}^{k}x^i)^{a-b-1}((\sum_{i=0}^{k}x^i)^2\bmod x^{k+1})
(i=0∑kxi)a−b−1((i=0∑kxi)2modxk+1)
展开右边
(
∑
i
=
0
k
x
i
)
a
−
b
−
1
(
∑
i
=
0
k
(
i
+
1
)
x
i
)
(\sum_{i=0}^k x^i)^{a-b-1}(\sum_{i=0}^{k}(i+1)x^i)
(i=0∑kxi)a−b−1(i=0∑k(i+1)xi)
即
(
1
−
x
k
+
1
1
−
x
)
a
−
b
−
1
(
1
−
(
k
+
2
)
x
k
+
1
+
(
k
+
1
)
x
k
+
2
(
1
−
x
)
2
)
(\frac{1-x^{k+1}}{1-x})^{a-b-1}(\frac{1-(k+2)x^{k+1}+(k+1)x^{k+2}}{(1-x)^2})
(1−x1−xk+1)a−b−1((1−x)21−(k+2)xk+1+(k+1)xk+2)
去括号
(
1
−
x
k
+
1
)
a
−
b
−
1
(
1
−
x
)
−
(
a
−
b
+
1
)
(
1
−
(
k
+
2
)
x
k
+
1
+
(
k
+
1
)
x
k
+
2
)
(1-x^{k+1})^{a-b-1}(1-x)^{-(a-b+1)}(1-(k+2)x^{k+1}+(k+1)x^{k+2})
(1−xk+1)a−b−1(1−x)−(a−b+1)(1−(k+2)xk+1+(k+1)xk+2)
利用二项式定理
(
∑
i
=
0
∞
(
a
−
b
−
1
i
)
(
−
1
)
i
x
(
k
+
1
)
i
)
(
∑
i
=
0
∞
(
a
−
b
+
i
i
)
x
i
)
(
1
−
(
k
+
2
)
x
k
+
1
+
(
k
+
1
)
x
k
+
2
)
(\sum_{i=0}^{\infin}\binom{a-b-1}{i}(-1)^ix^{(k+1)i})(\sum_{i=0}^{\infin}\binom{a-b+i}{i}x^i)(1-(k+2)x^{k+1}+(k+1)x^{k+2})
(i=0∑∞(ia−b−1)(−1)ix(k+1)i)(i=0∑∞(ia−b+i)xi)(1−(k+2)xk+1+(k+1)xk+2)
定义
S
(
n
)
=
∑
(
k
+
1
)
i
+
j
=
n
(
a
−
b
−
1
i
)
(
−
1
)
i
(
a
−
b
+
i
i
)
S(n)=\sum_{(k+1)i+j=n}\binom{a-b-1}{i}(-1)^i\binom{a-b+i}{i}
S(n)=(k+1)i+j=n∑(ia−b−1)(−1)i(ia−b+i)
则有
g
(
a
)
=
S
(
b
)
−
(
k
+
2
)
S
(
b
−
k
−
1
)
+
(
k
+
1
)
S
(
b
−
k
−
2
)
g(a)=S(b)-(k+2)S(b-k-1)+(k+1)S(b-k-2)
g(a)=S(b)−(k+2)S(b−k−1)+(k+1)S(b−k−2)
求
S
(
n
)
S(n)
S(n)可以枚举
i
i
i,时间复杂度就是
O
(
σ
1
(
n
)
k
+
1
)
O(\frac{\sigma_1(n)}{k+1})
O(k+1σ1(n)),由于
σ
1
(
n
)
\sigma_1(n)
σ1(n)不会很大,可以通过此题。
代码
#include <cstdio>
int read()
{
int x=0,f=1;
char ch=getchar();
while((ch<'0')||(ch>'9'))
{
if(ch=='-')
{
f=-f;
}
ch=getchar();
}
while((ch>='0')&&(ch<='9'))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int maxn=100000;
const int mod=998244353;
int p[maxn+10],prime[maxn+10],cnt,mu[maxn+10],fac[maxn+10],inv[maxn+10],ifac[maxn+10];
int getprime()
{
p[1]=mu[1]=1;
for(int i=2; i<=maxn; ++i)
{
if(!p[i])
{
prime[++cnt]=i;
mu[i]=-1;
}
for(int j=1; (j<=cnt)&&(i*prime[j]<=maxn); ++j)
{
int x=i*prime[j];
p[x]=1;
if(i%prime[j]==0)
{
mu[x]=0;
break;
}
mu[x]=-mu[i];
}
}
fac[0]=1;
for(int i=1; i<=maxn; ++i)
{
fac[i]=1ll*fac[i-1]*i%mod;
}
inv[0]=inv[1]=1;
for(int i=2; i<=maxn; ++i)
{
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
ifac[0]=1;
for(int i=1; i<=maxn; ++i)
{
ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
}
return 0;
}
inline int C(int a,int b)
{
if((a<b)||(b<0)||(a<0))
{
return 0;
}
return 1ll*fac[a]*ifac[b]%mod*ifac[a-b]%mod;
}
inline int getsum(int a,int b,int k,int mx)
{
int ans=0;
for(int i=0; mx-(k+1)*i>=0; ++i)
{
int j=mx-(k+1)*i;
ans=(ans+((i&1)?(-1ll):(1ll))*C(a-b-1,i)*C(a-b+j,j))%mod;
if(ans<0)
{
ans+=mod;
}
}
return ans;
}
inline int count(int a,int b,int k)
{
int ans=(getsum(a,b,k,b)-1ll*(k+2)*getsum(a,b,k,b-k-1)+1ll*(k+1)*getsum(a,b,k,b-k-2))%mod;
if(ans<0)
{
ans+=mod;
}
return ans;
}
int n,m,k,f[maxn+10],g[maxn+10];
int getG(int d)
{
g[n/d]=count(n/d,m/d,k);
return 0;
}
int main()
{
getprime();
n=read();
m=read();
k=read();
if(m==0)
{
puts("1");
return 0;
}
for(int i=1; i*i<=m; ++i)
{
if(m%i==0)
{
if(n%i==0)
{
getG(i);
}
int j=m/i;
if((i*i!=m)&&(n%j==0))
{
getG(j);
}
}
}
for(int i=1; i<=n; ++i)
{
for(int j=i; j<=n; j+=i)
{
f[j]+=mu[j/i]*g[i];
if(f[j]>=mod)
{
f[j]-=mod;
}
if(f[j]<0)
{
f[j]+=mod;
}
}
}
int ans=0;
for(int i=1; i<=n; ++i)
{
if(n%i==0)
{
ans=(ans+1ll*inv[i]*f[i])%mod;
}
}
printf("%d\n",ans);
return 0;
}