一、什么是快速幂算法?
我们先来看这道题
这道题目乍一看会觉得并不难啊,题目短短一行而已,而且思路也很容易,求幂这种算法一般在初学程序设计语言的时候应该都有联系过,只要写一个简单的循环就能够搞定。(或者用pow函数)
这里假设我们使用循环来写!
#include<iostream>
using namespace std;
long long power(long long a, long long n)
{
long long result = 1;
for (long long i = 1; i <= n; i++)
{
result = result * a;
}
return result;
}
int main()
{
long long a, n, m;
cin >> a >> n >> m;
cout << (power(a, n) % m);
return 0;
}
但是这样真的对吗?
我们不妨来运行一下:
奇了怪了?为什么会变成这样子呢?先不急,我们再来考虑一下,为什么这道题会让我们取模呢?仅仅只是增加难度嘛?实则不然,我们知道”指数爆炸“吧,如果让我们计算2的100次方,他的数值将非常大,我们无法将其存入long long(存不下)。
那这题我们要怎么做呢?我们得解决这个数太大的问题,因此我们需要把大人物分解成小任务,下面我们先来看看取模的运算法则:
(a + b) % p = (a % p + b % p) % p (1)
(a - b) % p = (a % p - b % p ) % p (2)
(a * b) % p = (a % p * b % p) % p (3)
本题我们需要用到的是第三个公式,具体公式的证明思路这里我们就不多讲了,感兴趣的可以自行取了解一下。我们可以借助这个法则,只需要在循环乘积的每一步都提前进行“取模”运算,而不是等到最后直接对结果“取模”,也能达到同样的效果。
所以,我们的代码可以变成这个样子
#include<iostream>
using namespace std;
long long power(long long a, long long n,long long m)
{
long long result = 1;
for (long long i = 1; i <= n; i++)
{
result = result * a;
result = result % m;
}
return result;
}
int main()
{
long long a, n, m;
cin >> a >> n >> m;
cout << (power(a, n, m));
return 0;
}
运行结果:
得到了结果,我们来尝试提交一下。
我们会发现我们超时了!!
我们来考虑一下这个算法的时间复杂度,假设我们求2的100次方,那么将会执行100次循环。如果我们分析一下这个算法,就会发现这个算法的时间复杂度为O(N),其中N为指数。求一下小的结果还好,那如果我们要求很大呢?这个程序可能会运行很久很久!
二、快速幂算法
快速幂算法能帮我们算出指数非常大的幂,传统的求幂算法之所以时间复杂度非常高(为O(指数n)),就是因为当指数n非常大的时候,需要执行的循环操作次数也非常大。所以我们快速幂算法的核心思想就是每一步都把指数分成两半,而相应的底数做平方运算。这样不仅能把非常大的指数给不断变小,所需要执行的循环次数也变小,而最后表示的结果却一直不会变。
因此该题的方法是将 a^n 拆成若干项相乘的形式,即把 n 化成二进制形式,变成2的幂次相加。
下来,再让我们用代码来演示一下上面的算法:
long long power(long long a, long long n, long long m)
{
long long result = 1;
while (n > 0) {
if (n % 2 == 0)
{
//如果指数为偶数
n = n / 2;//把指数缩小为一半
a = a * a % m;//底数变大成原来的平方
}
else
{
//如果指数为奇数
n = n - 1;//把指数减去1,使其变成一个偶数
result = result * a % m;//此时记得要把指数为奇数时分离出来的底数的一次方收集好
n = n / 2;//此时指数为偶数,可以继续执行操作
a = a * a % m;
}
}
return result;
}
明显能过了!!!
我们还可以优化一下
#include<iostream>
using namespace std;
long long f(long long a, long long n, long long m)
{
long long ans = 1;
while (n)
{
if (n & 1)//此处等价于if(n%2==1)
ans = ans * a % m;
a = a * a % m;
n >>= 1;//此处等价于n=n/2
}
return ans % m;
}
int main()
{
long long a, b, m, p = 0;
cin >> a >> b >> m;
p = f(a, b, m);
cout << p;
return 0;
}
经过这样的优化此代码效率极高,大家可以自己去测试测试