Goffi and GCD
思路
题目要求统计满足 $\gcd(n-i,n)\cdot\gcd(n-j,n)=n^{k}$ 的数对 $(i,j)$($1\le i,j\le n$)的个数,即求 $\sum_{i=1}^{n}\sum_{j=1}^{n}\left[\gcd(n-i,n)\gcd(n-j,n)=n^{k}\right]$(答案对 $10^9+7$ 取模)。
显然有 $\gcd(n-i,n)\le n$,所以两个 gcd 的乘积不超过 $n^2$。当 $n>1$ 时:对于 $k\ge 3$ 可以直接特判答案为 $0$;对于 $k=2$,只有 $i=j=n$ 时 $\gcd(0,n)\gcd(0,n)=n^2$ 成立,所以答案一定是 $1$。
所以我们只要考虑 $k=1$ 的情况:
$$
\begin{aligned}
&\sum_{i=1}^{n}\sum_{j=1}^{n}\left[\gcd(n-i,n)\gcd(n-j,n)=n\right] \\
={}&\sum_{d\mid n}\sum_{i=1}^{n}\sum_{j=1}^{n}\left[\gcd(n-i,n)=d\right]\left[\gcd(n-j,n)=\tfrac{n}{d}\right] \\
={}&\sum_{d\mid n}\left(\sum_{x=1}^{n/d}\left[\gcd\left(x,\tfrac{n}{d}\right)=1\right]\right)\left(\sum_{y=1}^{d}\left[\gcd(y,d)=1\right]\right) \\
={}&\sum_{d\mid n}\varphi(d)\,\varphi\left(\tfrac{n}{d}\right)
\end{aligned}
$$
第一步:两个 gcd 都是 $n$ 的约数,乘积等于 $n$ 当且仅当 $\gcd(n-i,n)=d$ 且 $\gcd(n-j,n)=\frac{n}{d}$,按 $d$ 分类即可。第二步:$i$ 取遍 $1..n$ 时 $n-i$ 取遍 $0..n-1$,$\gcd(n-i,n)=d$ 等价于 $n-i=dx$ 且 $\gcd(x,\frac{n}{d})=1$,这样的 $i$ 共 $\varphi(\frac{n}{d})$ 个;同理满足 $\gcd(n-j,n)=\frac{n}{d}$ 的 $j$ 共 $\varphi(d)$ 个。实现时只需枚举 $i\le\sqrt{n}$ 的约数对 $(i,\frac{n}{i})$,当 $i\ne\frac{n}{i}$ 时该对贡献两次。
最后再特判一下 $n=1$:此时对任意 $k$ 都有 $\gcd(0,1)\gcd(0,1)=1=1^{k}$,所以答案一定是 $1$。
代码
/*
Author : lifehappy
*/
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
// Contest shorthands (plain text substitution).
// NOTE: `endl` is redefined to the character '\n', so an unqualified `endl`
// prints a newline without flushing, and a qualified `std::endl` written after
// this point would expand to `std::'\n'` and fail to compile.
#define mp make_pair
#define pb push_back
#define endl '\n'
using namespace std;

// Integer / pair aliases.
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;

// Numeric constants. pi stays a runtime-initialised const because acos is
// not constexpr in standard C++.
const double pi = acos(-1.0);
constexpr double eps = 1e-7;
constexpr int inf = 0x3f3f3f3f;  // == 1061109567
/*
 * Euler's totient phi(x): the number of integers in [1, x] coprime to x.
 *
 * Trial division: each prime factor p found is divided out of x completely
 * and multiplies the result by (1 - 1/p). Whatever remains above 1 after the
 * loop is a single prime factor, handled the same way.
 *
 * Returns 0 for x < 1 (phi is only defined for positive integers; without
 * the guard, x == 0 would reach `ans / x` and divide by zero).
 * `long long` here is the same type as the file's `ll` alias.
 */
long long eular(long long x) {
    if (x < 1) return 0;
    long long ans = x;
    // `i <= x / i` is equivalent to `i * i <= x` but cannot overflow i * i.
    for (long long i = 2; i <= x / i; i++) {
        if (x % i != 0) continue;
        while (x % i == 0) x /= i;
        // i still divides ans here, so dividing first keeps the value exact.
        ans = ans / i * (i - 1);
    }
    if (x > 1) ans = ans / x * (x - 1);  // leftover prime factor
    return ans;
}
const int mod = 1e9 + 7;
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
// ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
ll n, k;
while (scanf("%lld %lld", &n, &k) != EOF) {
ll ans = 0;
if (k == 2 || n == 1) ans = 1;
else if (k == 1) {
for (ll i = 1; i * i <= n; i++) {
if (n % i == 0) {
if (i * i != n) ans = (ans + eular(n/i) * eular(i) * 2) % mod;
else ans = ans = (ans + eular(i) * eular(i)) % mod;
}
}
}
printf("%lld\n", ans);
}
return 0;
}