G、(2019 ACM/ICPC 全国邀请赛(西安)B) Product
Weblink
https://nanti.jisuanke.com/t/39269
Problem && Solution
**Hint**：指数要模 mod = p − 1（费马小定理）。p 为奇质数时 p − 1 是偶数，2 在模 mod 意义下没有逆元，所以不能用"乘 2 的逆元"来代替 `/2`，必须先在整数上把 `/2` 算出来之后再取模。（我手贱多写了一个 mod，一开始还没看出来，找了半个小时才找到这个 bug……为什么会犯这么低级的错误，我好菜啊）
AC Code
#include <bits/stdc++.h>
using namespace std;
const int N = 5e6 + 6;
// Modulus for the exponent; main() sets it to p - 1 (Fermat's little theorem).
// This declaration precedes `#define int long long`, so the width is spelled
// out explicitly: a plain int would truncate p - 1 once p exceeds 2^31.
long long mod;
#define int long long
// Modular helpers. NOTE: arguments are substituted textually WITHOUT
// parentheses and may be evaluated more than once. For example,
// mult(a, (b) / 2) expands to 1ll * a * (b) / 2, i.e. (a * b) / 2 rather
// than a * (b / 2). Pass only side-effect-free, cheap expressions.
#define mult(x, y) (1ll * x * y >= mod ? 1ll * x * y % mod : 1ll * x * y)
#define minus(x, y) (1ll * x - y < 0 ? 1ll * x - y + mod : 1ll * x - y)
#define plus(x, y) (1ll * x + y >= mod ? 1ll * x + y - mod : 1ll * x + y)
// Conditional subtraction of mod; assumes 0 <= x < 2 * mod.
// (Was `x >= mod : x - mod : x`, which is not a valid expression.)
#define ck(x) ((x) >= mod ? (x) - mod : (x))
typedef long long ll;
ll n, m, p;
// primes[1..cnt]: primes collected by the sieve in init().
// num[i]: exponent of the smallest prime factor of i (set by the sieve).
// d[i]: after init(), sum_{k=1}^{i} k * d(k) mod `mod`, d(k) = number of divisors.
ll primes[N], cnt, num[N], d[N];
// phi[i]: after init(), sum_{k=1}^{i} phi(k) mod `mod`.
ll phi[N];
// vis[i]: set to true once the sieve marks i as composite.
bool vis[N];
// Memo tables for get_sum_sum / get_sum_phi on arguments beyond the sieve range.
unordered_map<ll, ll> M_sum;
unordered_map<ll, ll> M_phi;
// Linear (Euler) sieve over [1, n] that fills the lookup tables used by the
// prefix-sum queries:
//   phi[i] = sum_{k=1}^{i} phi(k)      mod `mod`
//   d[i]   = sum_{k=1}^{i} k * d(k)    mod `mod`, d(k) = number of divisors
// num[i] tracks the exponent of the smallest prime factor of i, which lets
// the divisor count of i * p be derived from that of i in O(1).
// Assumes n < N and that `mod` has already been set.
void init(ll n)
{
    d[1] = 1;
    phi[1] = 1;
    for (ll i = 2; i <= n; ++i) {
        if (vis[i] == 0) {
            primes[++cnt] = i;
            phi[i] = i - 1;
            d[i] = 2, num[i] = 1;
        }
        for (ll j = 1; j <= cnt && i * primes[j] <= n; ++j) {
            const ll x = i * primes[j];
            vis[x] = true;
            if (i % primes[j] == 0) {
                // primes[j] is the smallest prime of i; its exponent grows from
                // num[i] to num[i] + 1, so d(x) = d(i) / (num[i] + 1) * (num[i] + 2).
                phi[x] = phi[i] * primes[j];
                num[x] = num[i] + 1;
                // This division is exact only on the TRUE divisor count, so d[]
                // must not be reduced modulo `mod` during the sieve (a small
                // modulus, e.g. p = 3, would otherwise corrupt it). Counts for
                // x < N stay tiny, so no overflow.
                d[x] = d[i] / num[x] * (num[x] + 1);
                break;
            }
            // primes[j] is a new prime factor of x: multiplicative update.
            phi[x] = phi[i] * (primes[j] - 1);
            num[x] = 1;
            d[x] = d[i] * 2;
        }
    }
    // Turn the exact per-index values into prefix sums, reducing only here.
    for (ll i = 1; i <= n; ++i) {
        phi[i] = (phi[i] + phi[i - 1]) % mod;
        d[i] = d[i] * i % mod;
        d[i] = (d[i] + d[i - 1]) % mod;
    }
}
// Binary exponentiation: returns a^b mod `mod` for mod >= 1.
// The base is reduced (and normalised to non-negative) up front, so every
// intermediate product is below mod^2, which fits in signed 64 bits for
// mod < ~3.03e9. Non-positive exponents yield 1 % mod.
long long qpow(long long a, long long b, long long mod)
{
    a %= mod;
    if (a < 0) a += mod;
    long long res = 1 % mod;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Prefix sum of the constant function g(i) = 1 over i = 1..x, which is just x.
// It is the partner function in the phi prefix-sum recurrence below.
inline int g_sum(int x) { return x; }
// Returns Phi(x) = sum_{i=1}^{x} phi(i) mod `mod`.
// Small x is read from the sieve table; larger x uses the recurrence
//   Phi(x) = x(x+1)/2 - sum_{l=2}^{x} g(l) * Phi(x / l),   g = 1,
// grouping the l that share the same quotient x / l, with memoisation in M_phi.
inline ll get_sum_phi(int x)
{
    if (x <= N - 7) return phi[x];
    // find() instead of M_phi[x]: a memoised value of 0 must count as a hit
    // (with mod = 1 every value is 0 and the old check recursed exponentially),
    // and a miss must not insert a placeholder entry.
    auto it = M_phi.find(x);
    if (it != M_phi.end()) return it->second;

    // x(x+1)/2: halve the even factor in exact integer arithmetic first,
    // because 2 generally has no inverse modulo `mod` (p - 1 is even).
    ll a = x, b = x + 1;
    if (a % 2 == 0) a /= 2;
    else b /= 2;
    const ll ans = (a % mod) * (b % mod) % mod;

    ll res = 0;
    for (ll l = 2, r; l <= x; l = r + 1) {
        r = x / (x / l);
        // Evaluate the recursive call once (the old plus() macro did it twice).
        const ll sub = get_sum_phi(x / l);
        res = (res + (g_sum(r) - g_sum(l - 1)) % mod * sub) % mod;
    }
    const ll val = ((ans - res) % mod + mod) % mod;
    M_phi[x] = val;
    return val;
}
// Returns S(x) = sum_{k=1}^{x} k * d(k) mod `mod`, d(k) = number of divisors.
// Uses S(x) = sum_{k*t <= x} k * t = sum_k k * T(x / k) with T(q) = q(q+1)/2,
// grouping the k that share the same quotient x / k. Small x is read from the
// sieve table d[]; larger results are memoised in M_sum.
inline ll get_sum_sum(ll x)
{
    if (x <= N - 7) return d[x];
    // find() instead of M_sum[x]: a memoised value of 0 must count as a hit,
    // and a miss must not insert a placeholder entry.
    auto it = M_sum.find(x);
    if (it != M_sum.end()) return it->second;

    ll res = 0;
    for (ll l = 1, r; l <= x; l = r + 1) {
        const ll q = x / l;
        r = x / q;
        // Divide by 2 in exact integer arithmetic BEFORE reducing: 2 generally
        // has no inverse modulo `mod`. Both factors are reduced before the
        // product so it stays below mod^2.
        const ll block_sum = (l + r) * (r - l + 1) / 2 % mod;  // l + ... + r
        const ll tri = q * (q + 1) / 2 % mod;                   // 1 + ... + q
        res = (res + block_sum * tri) % mod;
    }
    M_sum[x] = res;
    return res;
}
signed main()
{
scanf("%lld%lld%lld", &n, &m, &p);
mod = p - 1;
init(N - 7);
ll ans = 0;
for(ll l = 1, r; l <= n; l = r + 1) {
r = n / (n / l);
ans = plus(ans, 1ll * get_sum_phi(n / l) * minus(get_sum_sum(r), get_sum_sum(l - 1)) % mod) % mod;
}
ans = plus(ans, ans) % mod;
ans = minus(ans, get_sum_sum(n));
printf("%lld\n", qpow(m, ans, p) % p);
return 0;
}