题目大意:给定正整数nn,求有多少个整数四元组
a
,
b
,
c
,
d
∈
[
0
,
n
−
1
]
a,b,c,d\in [0,n-1]
a,b,c,d∈[0,n−1]满足
a
b
=
c
d
(
m
o
d
n
)
ab=cd\pmod n
ab=cd(modn)。由于
n
n
n非常大,将以质因数分解的形式给出
n
=
∏
i
=
1
m
p
i
c
i
n=\prod_{i=1}^mp_i^{c_i}
n=∏i=1mpici。
m
≤
5
×
1
0
5
,
p
,
c
≤
1
0
9
m\le5\times 10^5,p,c\le10^9
m≤5×105,p,c≤109。
题解:由中国剩余定理的结论知我们只需要求出
n
=
p
c
n=p^c
n=pc的答案然后乘起来即可。
考虑这个怎么做,不妨设
c
n
t
i
=
∑
(
a
,
b
)
[
a
b
 
m
o
d
 
n
=
i
]
cnt_i=\sum_{(a,b)}[ab\bmod n=i]
cnti=∑(a,b)[abmodn=i],那么答案就是
∑
i
=
0
n
−
1
c
n
t
i
2
\sum_{i=0}^{n-1}cnt^2_i
∑i=0n−1cnti2。显然
c
n
t
0
=
n
2
−
∑
i
=
1
n
−
1
c
n
t
i
cnt_0=n^2-\sum_{i=1}^{n-1}cnt_i
cnt0=n2−∑i=1n−1cnti,考虑某个
c
n
t
i
(
i
>
0
)
cnt_i(i>0)
cnti(i>0)怎么算。
不妨设
i
=
p
k
q
i=p^kq
i=pkq,满足
gcd
(
p
,
q
)
=
1
\gcd(p,q)=1
gcd(p,q)=1。显然
0
≤
k
<
c
,
q
>
0
0\le k<c,q>0
0≤k<c,q>0。
那么
a
b
=
i
ab=i
ab=i,等价于
a
=
p
k
′
a
′
,
b
=
p
k
−
k
′
b
′
,
gcd
(
a
′
,
p
)
=
gcd
(
b
′
,
p
)
=
1
,
a
′
b
′
=
q
(
m
o
d
p
c
−
k
)
a=p^{k'}a',b=p^{k-k'}b',\gcd(a',p)=\gcd(b',p)=1,a'b'=q\pmod {p^{c-k}}
a=pk′a′,b=pk−k′b′,gcd(a′,p)=gcd(b′,p)=1,a′b′=q(modpc−k),并且每求出这样的一组
(
a
′
,
b
′
,
k
′
)
(a',b',k')
(a′,b′,k′),都会有
p
k
′
×
p
k
−
k
′
=
p
k
p^{k'}\times p^{k-k'}=p^k
pk′×pk−k′=pk组
(
a
,
b
,
k
)
(a,b,k)
(a,b,k)与之对应,而
(
a
′
,
b
′
,
k
′
)
(a',b',k')
(a′,b′,k′)的组数显然就是
ϕ
(
p
c
−
k
)
\phi(p^{c-k})
ϕ(pc−k),与
k
′
k'
k′无关,因此对于每个
k
′
k'
k′,答案就是
p
k
ϕ
(
p
c
−
k
)
p^{k}\phi(p^{c-k})
pkϕ(pc−k),因此
c
n
t
i
=
(
k
+
1
)
p
k
ϕ
(
p
c
−
k
)
=
(
k
+
1
)
ϕ
(
p
c
)
cnt_i=(k+1)p^k\phi(p^{c-k})=(k+1)\phi(p^c)
cnti=(k+1)pkϕ(pc−k)=(k+1)ϕ(pc)。
然后考虑对于每个
k
k
k有多少个
i
i
i,显然就是
ϕ
(
p
c
−
k
)
\phi(p^{c-k})
ϕ(pc−k)
首先考虑
∑
i
=
1
n
−
1
c
n
t
i
2
=
∑
k
=
0
c
−
1
(
k
+
1
)
2
ϕ
2
(
p
c
)
ϕ
(
p
c
−
k
)
=
ϕ
2
(
p
c
)
(
p
−
1
)
p
c
∑
k
=
1
c
k
2
(
1
p
)
k
\sum_{i=1}^{n-1}cnt_i^2=\sum_{k=0}^{c-1}(k+1)^2\phi^2(p^c)\phi(p^{c-k})=\phi^2(p^c)(p-1)p^c\sum_{k=1}^ck^2\left(\frac{1}{p}\right)^k
∑i=1n−1cnti2=∑k=0c−1(k+1)2ϕ2(pc)ϕ(pc−k)=ϕ2(pc)(p−1)pc∑k=1ck2(p1)k
后面那个怎么求:
S
2
=
∑
k
=
1
n
k
2
q
k
S
2
−
n
2
q
n
=
∑
k
=
2
n
(
k
−
1
)
2
q
k
−
1
q
S
2
−
n
2
q
n
+
1
=
∑
k
=
2
n
(
k
2
−
2
k
+
1
)
q
k
q
S
2
−
n
2
q
n
+
1
=
∑
k
=
2
n
k
2
q
k
−
2
∑
k
=
2
n
k
q
k
+
∑
k
=
2
n
q
k
q
S
2
−
n
2
q
n
+
1
=
S
2
−
q
+
2
(
S
1
−
q
)
+
S
0
−
q
S
2
=
n
2
q
n
+
1
−
q
+
2
(
S
1
−
q
)
+
S
0
−
q
q
−
1
S_2=\sum_{k=1}^n k^2q^k\\ S_2-n^2q^n=\sum_{k=2}^n(k-1)^2q^{k-1}\\ qS_2-n^2q^{n+1}=\sum_{k=2}^n(k^2-2k+1)q^k\\ qS_2-n^2q^{n+1}=\sum_{k=2}^nk^2q^k-2\sum_{k=2}^nkq^k+\sum_{k=2}^nq^k\\ qS_2-n^2q^{n+1}=S_2-q+2(S_1-q)+S_0-q\\ S_2=\frac{n^2q^{n+1}-q+2(S_1-q)+S_0-q}{q-1}
S2=k=1∑nk2qkS2−n2qn=k=2∑n(k−1)2qk−1qS2−n2qn+1=k=2∑n(k2−2k+1)qkqS2−n2qn+1=k=2∑nk2qk−2k=2∑nkqk+k=2∑nqkqS2−n2qn+1=S2−q+2(S1−q)+S0−qS2=q−1n2qn+1−q+2(S1−q)+S0−q
其中
S
1
=
∑
k
=
1
n
k
q
k
,
S
0
=
∑
k
=
1
n
q
k
S_1=\sum_{k=1}^nkq^k,S_0=\sum_{k=1}^nq^k
S1=∑k=1nkqk,S0=∑k=1nqk,求法同理。
剩余还有一个
c
n
t
0
2
cnt_0^2
cnt02,过程类似,略。
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define mod 1000000007
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans; }
inline int inv(int x) { return fast_pow(x,mod-2); }
inline lint squ(int x) { return (lint)x*x; }
inline int solve(int p,int c)
{
int n=fast_pow(p,c),t=n-fast_pow(p,c-1);if(t<0) t+=mod;
int q=inv(p),v=fast_pow(q,c),z=inv(q-1),c0=q*(v-1ll)%mod*z%mod;
int c1=((lint)c*v%mod*q-c0)%mod*z%mod;if(c1<0) c1+=mod;
int c2=((lint)c*c%mod*v%mod*q%mod-q-2ll*(c1-q)%mod+(lint)q*(c0-v)%mod)*z%mod;if(c2<0) c2+=mod;
return ((lint)t*t%mod*(p-1)%mod*n%mod*c2+squ((lint)n*n%mod-t*(p-1ll)%mod*n%mod*c1%mod))%mod;
}
int main()
{
int ans=1;
for(int T=inn(),p,c;T;T--) p=inn(),c=inn(),ans=(lint)ans*solve(p,c)%mod;
return !printf("%d\n",ans);
}