题目
积性函数,欧拉筛
思路
什么是积性函数?
具体的数学定义这里不赘述,简单来说,积性函数就是 f ( a b ) = f ( a ) f ( b ) f(ab)=f(a)f(b) f(ab)=f(a)f(b),比如本题中的幂运算: ( a b ) n = a n × b n (ab)^n= a^n\times b^n (ab)n=an×bn,所以函数 f ( x ) = x n f(x)=x^n f(x)=xn 就是一个积性函数。
什么是欧拉筛?
具体的数学定义这里不赘述,简单来说,欧拉筛就是用来快速地求小于等于 n n n 的所有素数的。
这道题怎么入手?
首先观察每一项,是幂运算并取模,这就少不了快速幂。
不过就算采用了快速幂算法,如果暴力的求每一项的结果,然后求异或和的话,也是会超时的。
再仔细观察一下
可以发现每一项满足积性函数 f ( x ) = x n f(x)=x^n f(x)=xn,也就是说,不用将每一项都用快速幂来求结果,只需要对素数项求快速幂即可,其余非素数项的结果可以直接由素数项的结果相乘得出。
为什么非得是素数项,我随便拿个数计算出结果,然后再把这个数的所有倍数的结果用相乘的方式算出来不行吗?
因为素数项的结果是无论如何都要用快速幂计算的,因为它没有非 1 1 1 的因数来间接算出,也就无法简化。所以,不管拿什么数用快速幂计算,都逃不过所有素数,所以只用快速幂计算素数项的结果是最优的。
那么怎么求所有素数项的快速幂呢?
使用素数筛选出小于等于 n n n 的所有素数即可,而素数筛中常用的并且效果还不错的就是欧拉筛,这里不赘述算法细节,具体见代码。
注意
在使用欧拉筛的过程中,需要保证 p r i m e s primes primes 中的素数是从小到大的顺序的,避免漏掉数,所以用数组是最好的选择,不能使用无序的哈希集合。
代码
#include <iostream>
#include <numeric> // accumulate()
#include <vector>
using namespace std;
using ULL = unsigned long long;
const int MOD = 1e9 + 7;
/**
* @brief 快速幂,不考虑乘法溢出
*
* @param a 底数
* @param b 指数
* @param p 模
* @return ULL a^b mod p
*/
ULL quick_pow(ULL a, ULL b, ULL p) {
ULL res = 1ULL;
while (b) {
if (b & 1) {
res = (res * a) % p;
}
b >>= 1;
a = (a * a) % p;
}
return res;
}
/**
* @brief 在欧拉筛的基础上解决问题
*
* @param n 题目给定的数
* @return int 答案
*/
int solve(int n) {
vector<int> primes; // 装素数
vector<bool> is_prime(n + 1, true); // 判断是不是素数
vector<int> tmp(n + 1, 0); // 装每一项的结果,每一项是指 i^N mod p
int i = 0, j = 0;
tmp[1] = 1; // 第一项始终为1,1^n = 1
for (i = 2; i <= n; i++) {
// 当 i 是素数时
if (is_prime[i]) {
// 记录下来
primes.emplace_back(i);
// 只计算素数的幂,其余元素的幂可以由素数的幂间接算出
tmp[i] = quick_pow(i, n, MOD);
}
// 遍历当前所有素数,但不能超出 n 的范围
for (j = 0; j < primes.size() && i * primes[j] <= n; j++) {
// 既然都是 i * primes[j] 了,那么必然不是素数,筛去
is_prime[i * primes[j]] = false;
// 积性函数,非素数的幂可以由素数的幂算出,注意乘法溢出
tmp[i * primes[j]] = (ULL)tmp[i] * tmp[primes[j]] % MOD;
// 如果当前数是当前遍历到的素数的倍数,则直接退出循环结束遍历
// 主要是提高效率,避免不必要的计算
if (i % primes[j] == 0) break;
}
}
// 最后将 tmp 的数据全部求异或和,由于异或是不会溢出的,所以不用处理溢出问题
return accumulate(tmp.begin(), tmp.end(), 0,
[](const int res, const int cur) { return res ^ cur; });
}
int main(void) {
int n = 0;
cin >> n;
cout << solve(n) << endl;
return 0;
}