题目
https://gmoj.net/senior/#main/show/3984
https://www.luogu.com.cn/problem/P2162
题解
大佬们说是Pόlya定理裸题,但是我太菜了没学会那么高深的定理……
令
f
i
f_i
fi表示不考虑去重,构造一个长度为 i 的串的方案。
那么
f
i
=
∑
j
=
1
17
(
−
1
)
j
⋅
(
17
j
)
⋅
i
j
f_i=\sum_{j=1}^{17}(-1)^j\cdot\tbinom{17}{j}\cdot i^j
fi=j=1∑17(−1)j⋅(j17)⋅ij
似乎输出
f
n
n
\frac{f_n}{n}
nfn就好了,但是这样子显然是有问题的 ,因为如果这样子就能切,出题人也太友善了吧 。
假如现在构造出了一个串1,2,3,…,16,17,1,2,3,…,16,17…,这个串不仅仅是出现了n次,所以这样子做会错。
那么考虑枚举每个重复串的长度(如上面的重复串长度为17),最显而易见的式子是(为什么要乘上n/k?这是为了抵消外面的除以n):
a
n
s
=
∑
k
∣
n
f
k
⋅
n
k
n
ans=\cfrac{\sum_{k|n}f_k\cdot\frac{n}{k}}{n}
ans=n∑k∣nfk⋅kn
然而一个长度为 i 的重复串可能会被长度为2i,3i……的串重复计算,因此
a
n
s
=
∑
k
∣
n
f
k
⋅
φ
(
n
k
)
n
ans=\cfrac{\sum_{k|n}f_k\cdot\varphi(\frac{n}{k})}{n}
ans=n∑k∣nfk⋅φ(kn)
接下来就是喜闻乐见的高精度运算啦!
其实这题没必要打高精度除以单精度,可以把答案处理成
p
⋅
n
+
q
p\cdot n+q
p⋅n+q的形式,其中p用高精度处理。接着如果计算没有出问题的话,最后必定有q=0,因此输出p的前120位就好了。
CODE
这里我把高精度和pn+q封装了一下,因此代码略长。
#include<cmath>
#include<cstdio>
#include<cstring>
using namespace std;
#define ll long long
#define P 1000000
#define M 32005
#define N 55
int C[25][25],pri[M],p[35],n,s;
bool b[M];
inline int max(int x,int y){return x>y?x:y;}
inline ll div(ll x,int y){ll z=x/y;return z+(z*y<x);}
struct bignum
{
ll a[N];int len;
bignum(){memset(a,0,sizeof a),len=1;}
inline void print()
{
for(int i=20;i;--i) printf("%06lld",a[i]);
putchar('\n');
}
inline void del(){while(len>20) a[len]=0,--len;}
}_0;
inline bignum plus(bignum x,bignum y)
{
bignum z;
z.len=max(x.len,y.len);
for(int i=1;i<=z.len;++i)
{
z.a[i]+=x.a[i]+y.a[i];
if(z.a[i]>=P) z.a[i+1]+=z.a[i]/P,z.a[i]%=P;
}
if(z.a[z.len+1]) ++z.len;
z.del();return z;
}
inline bignum minus(bignum x,bignum y)
{
for(int i=1;i<=x.len;++i)
{
x.a[i]-=y.a[i];
if(x.a[i]<0)
{
ll tmp=div(-x.a[i],P);
x.a[i+1]-=tmp,x.a[i]+=tmp*P;
}
}
while(x.len>1&&!x.a[x.len]) --x.len;
return x;
}
inline bignum times(bignum x,bignum y)
{
bignum z;z.len=x.len+y.len-1;
for(int i=1;i<=x.len;++i)
for(int j=1;j<=y.len;++j)
z.a[i+j-1]+=x.a[i]*y.a[j],
z.a[i+j]+=z.a[i+j-1]/P,
z.a[i+j-1]%=P;
for(int i=1;i<=z.len;++i)
z.a[i+1]+=z.a[i]/P,z.a[i]%=P;
while(z.a[z.len+1])
++z.len,z.a[z.len+1]+=z.a[z.len]/P,z.a[z.len]%=P;
z.del();return z;
}
inline bignum plus(bignum x,ll y)
{
x.a[1]+=y;
for(int i=1;i<=x.len&&x.a[i]>=P;++i)
++x.a[i+1],x.a[i]-=P;
while(x.a[x.len+1])
++x.len,x.a[x.len+1]+=x.a[x.len]/P,x.a[x.len]%=P;
x.del();return x;
}
inline bignum minus(bignum x,ll y)
{
ll tmp;x.a[1]-=y;
for(int i=1;i<=x.len&&x.a[i]<0;++i)
{
tmp=div(-x.a[i],P);
x.a[i+1]-=tmp,x.a[i]+=tmp*P;
}
while(x.len>1&&!x.a[x.len]) --x.len;
x.del();return x;
}
inline bignum times(bignum x,ll y)
{
bignum z;z.len=x.len;
for(int i=1;i<=x.len;++i)
z.a[i]+=x.a[i]*y,
z.a[i+1]+=z.a[i]/P,z.a[i]%=P;
while(z.a[z.len+1])
++z.len,z.a[z.len+1]+=z.a[z.len]/P,z.a[z.len]%=P;
z.del();return z;
}
struct number
{
bignum p;ll q;
number(){p=_0,q=0;}
inline void update(){if(q>=n) p=plus(p,q/n),q%=n;}
}ans;
inline number times(number x,ll y)
{
x.p=times(x.p,y),x.q*=y;
x.update();return x;
}
inline number times(number x,number y)
{
x.p=plus(times(times(x.p,y.p),n),plus(times(x.p,y.q),times(y.p,x.q))),x.q*=y.q;
x.update();return x;
}
inline number plus(number x,number y)
{
x.p=plus(x.p,y.p),x.q+=y.q;
x.update();return x;
}
inline number minus(number x,number y)
{
x.p=minus(x.p,y.p),x.q-=y.q;
if(x.q<0) x.p=minus(x.p,1),x.q+=n;
return x;
}
inline number pow(int X,int y)
{
number s,x;
x.q=X,x.update(),s.q=1;
while(y)
{
if(y&1) s=times(s,x);
x=times(x,x),y>>=1;
}
return s;
}
inline int phi(int k)
{
int res=k,i;
for(i=1;i<=s;++i) if(k%p[i]==0)
res=res/p[i]*(p[i]-1);
return res;
}
inline void calc(int p)
{
number num;
for(int i=1;i<18;i+=2) num=plus(num,times(pow(i,p),C[17][i]));
for(int i=2;i<18;i+=2) num=minus(num,times(pow(i,p),C[17][i]));
ans=plus(ans,times(num,phi(n/p)));
}
int main()
{
int m,i,j;
scanf("%d",&n);
C[0][0]=1;
for(i=1;i<18;++i)
{
C[i][0]=C[i][i]=1;
for(j=1;j<i;++j) C[i][j]=C[i-1][j]+C[i-1][j-1];
}
for(i=2;i<M;++i)
{
if(!b[i]) pri[++pri[0]]=i;
for(j=1;j<=pri[0]&&i*pri[j]<M;++j)
{b[i*pri[j]]=1;if(i%pri[j]==0) break;}
}
m=n;
for(i=1;i<=pri[0];++i)
if(m%pri[i]==0)
{
p[++s]=pri[i];
while(m%pri[i]==0) m/=pri[i];
}
if(m>1) p[++s]=m;m=sqrt(n);
for(i=1;i<=m;++i) if(n%i==0)
{calc(i);if(i!=n/i) calc(n/i);}
ans.p.print();
return 0;
}