快速幂
所谓快速幂,就是快速求以a为底的n次幂(即an)的相关操作。
1.低精度的快速幂
1.例如这题:875. 快速幂 - AcWing题库
由数学知识可知:a * b mod m = (a mod m) * (b mod m) mod m。那么就可以将ab mod p转换为若干ak(k = 2n) mod p的乘积。这一操作在二进制上进行,例如:35 mod 2,5的二进制为101,那么就是(3a mod 2) * (3b mod 2) mod 2【a = 20,b = 22】
#define ll long long
ll f(ll a,int b,int c)//a^b%c
{
ll ans = 1;
while(b)
{
if(b&1) ans = ans * a % c;
b >>= 1;
a = a * a % c;
}
return ans;
}
2.那么,如果要求的是最后n位数而不是取模呢?这里假设是三位数。
- 写法一:
ll f(ll a,int b)
{
ll ans = 1;
while(b)
{
if(b&1) ans = ans * a % 1000;
b >>= 1;
a = a * a % 1000;
}
return ans;
}
- 写法二:
ll f(ll a,int b)
{
ll ans = 1;
for(int i = 0;i < b;i++)
{
ans *= b;
ans %= 1000;
}
return ans;
}
2.高精度的快速幂
高精度快速幂常常用于求最后多位数字,例如这题:P1249 最大乘积 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
这题涉及到了贪心算法,这里不加讨论,我们所需要知道的就是它的计算方法是从2开始往上逐步增加,最后减去相应的数即可。例如:15就是2+3+4+5+6-5。13就是2+3+4+5-2+1=2+3+4+6。
首先利用头文件cmath内置的log10()来计算数的位数。
log10(2)* p + 1;
它的快速幂分为三个模块,我们一一来看。
1.主体部分:
int ans[N],f[N],sav[N];
while(p)
{
if(p&1)f1();
p >>= 1;
f2();
}
2.函数f1()
void f1()
{
memset(sav, 0, sizeof sav);
for (int i = 1; i <= 500; i++)
for (int j = 1; j <= 500; j++)
sav[i + j - 1] += ans[i] * f[j];
for (int i = 1; i <= 500; i++)
{
sav[i + 1] += sav[i] / 10;
sav[i] %= 10;
}
memcpy(ans, sav, sizeof sav);
}
3.函数f2()
void f2()
{
memset(sav, 0, sizeof sav);
for (int i = 1; i <= 500; i++)
for (int j = 1; j <= 500; j++)
sav[i + j - 1] += f[i] * f[j];
for (int i = 1; i <= 500; i++)
{
sav[i + 1] += sav[i] / 10;
sav[i] %= 10;
}
memcpy(f, sav, sizeof f);
}
4.初始化与结果减1需注意:
//初始化
f[1] = 2;
ans[1] = 1;
…………
ans[1] -= 1;
下面看看完整代码:
#include<iostream>
#include<cmath>
#include<cstring>
using namespace std;
#define ll long long
const int N = 1005;
int p;
int ans[N], f[N], sav[N];
void f1()
{
memset(sav, 0, sizeof sav);
for (int i = 1; i <= 500; i++)
for (int j = 1; j <= 500; j++)
sav[i + j - 1] += ans[i] * f[j];
for (int i = 1; i <= 500; i++)
{
sav[i + 1] += sav[i] / 10;
sav[i] %= 10;
}
memcpy(ans, sav, sizeof sav);
}
void f2()
{
memset(sav, 0, sizeof sav);
for (int i = 1; i <= 500; i++)
for (int j = 1; j <= 500; j++)
sav[i + j - 1] += f[i] * f[j];
for (int i = 1; i <= 500; i++)
{
sav[i + 1] += sav[i] / 10;
sav[i] %= 10;
}
memcpy(f, sav, sizeof f);
}
int main()
{
cin >> p;
cout << int(log10(2) * p + 1) << endl;
ans[1] = 1;
f[1] = 2;
while (p)
{
if (p & 1)f1();
p >>= 1;
f2();
}
ans[1] -= 1;
for (int i = 500; i >= 1; i--)
{
if (i % 50 == 0 && i != 500)cout << endl;
cout << ans[i];
}
return 0;
}