从一个简单的问题说起:
给出整数m,n和p,要求计算(m ^ n) % p的结果。
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1;
for (long long i = 0; i < n; i++) {
ans = ans * m;
}
cout << ans << "\n";
return 0;
}
这个程序似乎正确了,但是存在严重问题:
<1>.m或n太大,极容易溢出.
<2>.如果n的值太大,时间消耗O(n)代价较大.
首先解决溢出的问题:
显然:
(a * b) % c = ((a % c) * (b % c)) % c.
这样,就可以把程序改写为如下形式:
但是,如果n的值太大,时间消耗O(n)代价太大,这个问题如何解决呢?
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1;
for (long long i = 0; i < n; i++) {
ans = ((ans % p) * (m % p)) % p;
}
cout << ans << "\n";
return 0;
}
乘方快速幂:
假设要计算 m^10,m^10 = (m^5) ^ 2 = (m * (m ^ 2) ^ 2) ^ 2.
也就是说,要计算m ^ n,有:
那么,程序就变成了:
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1;
while (n) {
if (n % 2 != 0) {
ans = ((ans % p) * (m % p)) % p;
}
n = n / 2;
m = ((m % p) * (m % p)) % p;
}
cout << ans << "\n";
return 0;
}
但是,对于这个程序,我们仍可以继续对其优化:
首先介绍一下 按位与运算(&) 与 右移运算(>>):
<1>.按位与运算:
对于两个二进制数,它们按位与运算的结果是: 对于每一位,如果两个数的这一位同时为1,那么按位与的结果便是1,否则为0,最后将结果转化为十进制,就是我们想要的答案了。 对于一个整数,如果它是奇数,那么它的二进制表示的最低位为1,否则为0,那么对于奇数而言,其按位与1的结果是1,对于偶数而言,其按位与1的结果是0,由此我们可以通过判断一个整数按位与1的结果来判断其是偶数还是奇数.
<2>.右移运算:
同样是对2进制数进行处理,将所有位置上的数字右移,高位补0:如5:101,右移一位为010,结果是2。则:对于一个整数而言,右移一位,相当于其除以2并向下取整。
我们可以根据这两个运算来初步优化程序:
即将 n % 2 != 0 改为 n & 1 == 1,将 n = n / 2 改为 n = n >> 1.
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1;
while (n) {
if (n & 1) {
ans = ((ans % p) * (m % p)) % p;
}
n = n >> 1;
m = ((m % p) * (m % p)) % p;
}
cout << ans << "\n";
return 0;
}
对于m ^ 0,结果为1,1 % 1 == 0,所以,我们应该要防止这种特殊情况,即在进行乘方运算之前,先将ans % p:
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1 % p;
while (n) {
if (n & 1) {
ans = ((ans % p) * (m % p)) % p;
}
n = n >> 1;
m = ((m % p) * (m % p)) % p;
}
cout << ans << "\n";
return 0;
}
因为C++内置的最高整数类型是64位,若运算 (a ^ b) % p中的三个变量a,b,p都在10^18级别,则不存在一个可供强制转化的128位整数类型,我们需要一些特殊的处理办法:
进行乘方运算之前,先让m对p取模一次:
#include <iostream>
using namespace std;
int main() {
long long m, n, p;
cin >> m >> n >> p;
long long ans = 1 % p;
m %= p;
while (n) {
if (n & 1) {
ans = ((ans % p) * (m % p)) % p;
}
n = n >> 1;
m = ((m % p) * (m % p)) % p;
}
cout << ans << "\n";
return 0;
}
这样就是最优的形式了。
下面给出几道相关的练习题:
Raising Modulo Numbers
我们可以计算每一项a^b的值,然后将其加起来作为结果:
#include <iostream>
#define i64 long long
i64 qpow(i64 a, i64 b, i64 p) {
i64 ans = 1 % p;
a %= p;
while (b) {
if (b & 1) {
ans = ((ans % p) * (a % p)) % p;
}
b >>= 1;
a = ((a % p) * (a % p)) % p;
}
return ans;
}
int main() {
int t; std::cin >> t;
while (t--) {
i64 M;
std::cin >> M;
i64 H, ans = 0;
std::cin >> H;
for (int i = 0; i < H; i++) {
i64 A, B;
std::cin >> A >> B;
ans = ((ans % M) + (qpow(A, B, M) % M)) % M;
}
std::cout << ans << "\n";
}
return 0;
}
Pseudoprime numbers
题意:
输入p 和 a,如果p不是质数,并且a>1并且(a^p) % p == a % p,那么输出yes,否则输出no
参考代码:
#include <iostream>
using namespace std;
bool isprime(long long n) {
if (n < 2) {
return false;
}
for (int i = 2; i <= n / i; i++) {
if (n % i == 0) {
return false;
}
}
return true;
}
long long qpow(long long m, long long n, long long p) {
long long ans = 1 % p;
while (n) {
if (n & 1) {
ans = ((ans % p) * (m % p)) % p;
}
n = n >> 1;
m = ((m % p) * (m % p)) % p;
}
return ans;
}
int main() {
long long p, a;
while (cin >> p >> a && p && a) {
if (isprime(p) == false && qpow(a, p, p) == a % p && a > 1) {
cout << "yes\n";
} else {
cout << "no\n";
}
}
return 0;
}
方阵快速幂: