注意
该文章原文于 2023-10-28 16:37 发表于洛谷,链接与文章内容在此留作保存。
一道比较简单的莫反练习题。
题意
给定长度为 $n$ 的数列 $a_1, a_2, a_3, \dots, a_n$,求:
$$\left( \sum_{i=1}^{n} \sum_{j=i+1}^{n} \operatorname{lcm}(a_i, a_j) \right) \bmod 998244353$$
题解
首先,考虑把问题转化一下。设 $sum = \sum_{i=1}^{n} \sum_{j=1}^{n} \operatorname{lcm}(a_i, a_j)$。由于 $\operatorname{lcm}$ 关于两个参数对称,且 $\operatorname{lcm}(a_i, a_i) = a_i$,那么有
$$\sum_{i=1}^{n} \sum_{j=i+1}^{n} \operatorname{lcm}(a_i, a_j) = \frac{sum - \sum_{i=1}^{n} a_i}{2}$$
所以考虑如何求出 $sum$。发现直接处理 $a_i$ 并不好做,但是观察到 $1 \leq a_i \leq 10^6$,考虑给每个数字开一个桶然后计数。不妨用 $c_i$ 表示数字 $i$ 出现的次数,用 $m$ 表示 $a_i$ 中的最大值,那么有:
$$sum = \sum_{i=1}^{m} \sum_{j=1}^{m} \operatorname{lcm}(i, j) \times c_i \times c_j$$
所以可以莫反:
$$sum = \sum_{i=1}^{m} \sum_{j=1}^{m} \frac{ij}{\gcd(i, j)} \times c_i \times c_j$$
枚举 $d = \gcd(i, j)$(把 $i, j$ 分别替换为 $id, jd$),再利用 $[\gcd(i, j) = 1] = \sum_{g \mid \gcd(i, j)} \mu(g)$ 展开:
$$
\begin{aligned}
& \sum_{d=1}^{m} \sum_{i=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} \frac{ijd^2}{d} \times c_{id} \times c_{jd} \ [\gcd(i, j) = 1] \\
=& \sum_{d=1}^{m} \sum_{i=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} ijd \times c_{id} \times c_{jd} \ [\gcd(i, j) = 1] \\
=& \sum_{d=1}^{m} \sum_{g=1}^{\lfloor \frac{m}{d} \rfloor} \mu(g) \sum_{i=1}^{\lfloor \frac{m}{dg} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{dg} \rfloor} ijdg^2 \times c_{idg} \times c_{jdg}
\end{aligned}
$$
设 $T = dg$(于是 $g = \frac{T}{d}$,$dg^2 = \frac{T^2}{d}$):
$$
\begin{aligned}
& \sum_{T=1}^{m} \sum_{d \mid T} \mu\left( \frac{T}{d} \right) \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{T} \rfloor} \frac{ijT^2}{d} \times c_{iT} \times c_{jT} \\
=& \sum_{T=1}^{m} \sum_{d \mid T} \mu\left( \frac{T}{d} \right) \frac{T^2}{d} \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{T} \rfloor} i \times j \times c_{iT} \times c_{jT} \\
=& \sum_{T=1}^{m} \sum_{d \mid T} \mu\left( \frac{T}{d} \right) \frac{T^2}{d} \left( \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} i \times c_{iT} \right)^2
\end{aligned}
$$
设 $f(T) = \sum_{d \mid T} \mu\left( \frac{T}{d} \right) \frac{T^2}{d}$,$g(T) = \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} i \times c_{iT}$。这两个都可以通过枚举倍数在 $\mathcal{O}(m \ln m)$ 的时间复杂度内求出,那么最后:
$$sum = \sum_{T=1}^{m} f(T)\, g^2(T)$$
CODE(注意:代码中变量名与上文相反——`m` 表示数字个数,`n` 表示 $a_i$ 的最大值):
#include <iostream>
#include <cstdio>
using namespace std;
// Integer type used for every quantity that is reduced modulo `mod`.
using ll = long long;

// Capacity of the per-value tables: values go up to 1e6, plus slack for 1-based indexing.
constexpr int maxn = 1000005;
// Upper bound on an input value (the statement guarantees 1 <= a_i <= 1e6).
constexpr int maxk = 1000000;
// Prime modulus required by the problem statement.
constexpr ll mod = 998244353;

// NOTE: the roles of n and m are swapped relative to the write-up above.
int n = 0;    // largest value read from the input; every table is filled up to n
int m = 0;    // number of input values
int tot = 0;  // number of primes stored in prime[] by the sieve in prework()

int prime[maxn >> 1];  // prime[1..tot]: primes <= n, in increasing order
int mu[maxn];          // mu[v]: Mobius function of v
ll c[maxn];            // c[v]: number of occurrences of value v in the input
ll f[maxn];            // f(T) = sum_{d | T} mu(T/d) * T^2 / d; prework() leaves prefix sums here
ll g[maxn];            // g(T) = sum_{k >= 1, kT <= n} k * c[kT]
bool not_prime[maxn];  // sieve marks: true for composites <= n
// Modular exponentiation by repeated squaring: returns a^k mod `mod`, in [0, mod).
//   a - base. Any long long is accepted: it is first reduced into [0, mod), so
//       the product a * a below stays below mod^2 and cannot overflow long long.
//   k - exponent; must be >= 0. Negative exponents are not supported (the loop
//       does not run and 1 is returned instead of hanging on an arithmetic shift).
// Uses O(log k) modular multiplications.
ll fpm(ll a, ll k) {
    a %= mod;
    if (a < 0) a += mod;  // C++ '%' keeps the sign of the dividend
    ll res = 1;
    while (k > 0) {
        if (k & 1) res = res * a % mod;
        a = a * a % mod;
        k >>= 1;
    }
    return res;
}
// Multiplicative inverse of 2 modulo `mod`. Since `mod` is an odd prime,
// 2 * ((mod + 1) / 2) = mod + 1, which is 1 modulo `mod`; this is exactly the
// value fpm(2, mod - 2) produces via Fermat's little theorem (499122177 here).
const ll inv2 = (mod + 1) / 2;
void prework() {
mu[1] = 1;
for(int i = 2; i <= n; i++) {
if(!not_prime[i]) prime[++tot] = i, mu[i] = -1;
for(int j = 1; j <= tot && i*prime[j] <= n; j++) {
not_prime[i*prime[j]] = true;
if(i%prime[j] == 0) {
mu[i*prime[j]] = 0;
break;
}
mu[i*prime[j]] = -mu[i];
}
}
for(int i = 1; i <= n; i++)
for(int j = 1; i*j <= n; j++)
f[i*j] = (f[i*j] + (1ll * mu[j] * i * j % mod * j % mod + mod) % mod) % mod;
for(int i = 1; i <= n; i++)
f[i] = (f[i] + f[i-1]) % mod;
for(int T = 1; T <= n; T++)
for(int i = 1; i <= n/T; i++)
g[T] = (g[T] + 1ll * i * c[i*T] % mod) % mod;
}
int main() {
scanf("%d", &m);
ll res = 0;
for(int i = 1, x; i <= m; i++) {
scanf("%d", &x);
c[x]++;
res = (res - x%mod + mod) % mod;
n = max(n, x);
}
prework();
for(int i = 1; i <= n; i++) {
ll x = (f[i] - f[i-1] + mod) % mod;
res = (res + x * g[i] % mod * g[i] % mod) % mod;
}
printf("%lld\n", res * inv2 % mod);
return 0;
}