题意
求下式的值:
$$S_n = \left\lceil \left(a+\sqrt{b}\right)^n \right\rceil \bmod m$$
其中:
$0 < a,\ m < 2^{15}$
$0 < b,\ n < 2^{31}$
$(a-1)^2 < b < a^2$
解析
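令:
$A_n = \left(a+\sqrt{b}\right)^n$
$B_n = \left(a-\sqrt{b}\right)^n$
$C_n = A_n + B_n$
由二项式展开, $A_n$ 与 $B_n$ 中含 $\sqrt{b}$ 奇数次幂的项正负抵消, 所以 $C_n$ 是整数。
因为 $(a-1)^2 < b < a^2$ (且 $a \ge 1$), 所以 $a-1 < \sqrt{b} < a$, 即 $0 < a-\sqrt{b} < 1$。
所以 $0 < \left(a-\sqrt{b}\right)^n < 1$, 即 $0 < B_n < 1$。
于是 $C_n - 1 < A_n = C_n - B_n < C_n$, 也就是说 $C_n = \lceil A_n \rceil$, $S_n = C_n \bmod m$。
因此, 求出 $C_n \bmod m$ 就行了。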
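将 $C_n$ 乘以 $A_1 + B_1 = 2a$:
$$\begin{aligned}
C_n \cdot \left[(a+\sqrt{b}) + (a-\sqrt{b})\right]
&= \left[(a+\sqrt{b})^n + (a-\sqrt{b})^n\right]\left[(a+\sqrt{b}) + (a-\sqrt{b})\right] \\
&= (a+\sqrt{b})^{n+1} + (a-\sqrt{b})^{n+1} + (a+\sqrt{b})^n(a-\sqrt{b}) + (a-\sqrt{b})^n(a+\sqrt{b}) \\
&= C_{n+1} + (a^2-b)(a+\sqrt{b})^{n-1} + (a^2-b)(a-\sqrt{b})^{n-1} \\
&= C_{n+1} + (a^2-b)\,C_{n-1}
\end{aligned}$$
所以:
$$C_{n+1} = 2a\,C_n - (a^2-b)\,C_{n-1}$$
写成矩阵形式 (初值 $C_0 = 2$, $C_1 = 2a$):
$$\begin{bmatrix} C_{n+1} \\ C_n \end{bmatrix} = \begin{bmatrix} 2a & -(a^2-b) \\ 1 & 0 \end{bmatrix}^{n} \begin{bmatrix} C_1 \\ C_0 \end{bmatrix}$$
至此, 公式推导完毕, 用矩阵快速幂求解就行了 (代码中计算该矩阵的 $n-1$ 次幂, 由 $[C_n, C_{n-1}]^T$ 取出 $C_n$)。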
代码
#pragma comment(linker, "/STACK:1677721600")
#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <climits>
#include <cassert>
#include <iostream>
#include <algorithm>
// ---- Common contest shorthands (most are not used by the visible code) ----
// Container helpers.
#define pb push_back
#define mp make_pair
// 64-bit integer type used for matrix entries and for the inputs a, b, n.
#define LL long long
// Segment-tree child-range argument lists; not used by the visible code.
#define lson lo,mi,rt<<1
#define rson mi+1,hi,rt<<1|1
// Function-like min/max: the selected argument is evaluated twice.
#define Min(a,b) ((a)<(b)?(a):(b))
#define Max(a,b) ((a)>(b)?(a):(b))
// memset wrappers: all-zero bytes, all-0xFF bytes (-1 for ints), arbitrary byte.
#define mem0(a) memset(a,0,sizeof(a))
#define mem1(a) memset(a,-1,sizeof(a))
#define mem(a,b) memset(a,b,sizeof(a))
// Redirect stdin/stdout to local files; FIN is used in main() when LOCAL is defined.
#define FIN freopen("in.txt", "r", stdin)
#define FOUT freopen("out.txt", "w", stdout)
using namespace std;
// Miscellaneous constants; none of them is referenced by the visible code.
const double eps = 1e-8;
const double ee = exp(1.0);
// Every byte is 0x3f, so memset(x, 0x3f, sizeof x) fills an int array with this value.
const int inf = 0x3f3f3f3f;
const int maxn = 1e3 + 10;
const double pi = acos(-1.0);
const LL iinf = 0x3f3f3f3f3f3f3f3f;
// Reads the next (optionally negative) decimal integer from stdin.
// Any leading characters that are neither digits nor '-' are skipped.
// Returns 0 if end of input is reached before a digit or '-' is seen
// (the original looped forever in that case because `char` could not hold EOF).
// NOTE(review): no overflow check -- assumes the value fits in int.
int readT()
{
    int c;  // int, not char: keeps EOF distinguishable from any real byte
    int ret = 0;
    bool neg = false;
    // Skip separators, but stop at end of input instead of spinning forever.
    while ((c = getchar()) != EOF && (c < '0' || c > '9') && c != '-')
        ;
    if (c == EOF)
        return 0;
    if (c == '-')
        neg = true;
    else
        ret = c - '0';
    while ((c = getchar()) != EOF && c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    return neg ? -ret : ret;
}
// Modulus m of the current test case: assigned in main(), read by mul().
int mod;
// One matrix row of 64-bit entries.
typedef std::vector<long long> vec;
// Dense row-major matrix built from rows.
typedef std::vector<vec> mat;
// Multiplies matrices A (r x k) and B (k x c) modulo the global `mod`.
// Every result entry is reduced into [0, mod), even when A or B hold negative
// or unreduced entries (main seeds the transition matrix from b - a*a).
// Inputs are taken by const reference: they are only read, and const objects
// or temporaries can now be passed as well (backward-compatible with callers).
// NOTE(review): assumes each product A[i][k] * B[k][j] fits in long long.
mat mul(const mat &A, const mat &B)
{
    // Guard B[0] so an empty right operand yields an r x 0 result, not UB.
    const size_t cols = B.empty() ? 0 : B[0].size();
    mat C(A.size(), vec(cols));
    for (size_t i = 0; i < A.size(); i++)
    {
        for (size_t k = 0; k < B.size(); k++)
        {
            const LL aik = A[i][k];  // invariant across the inner loop
            for (size_t j = 0; j < cols; j++)
            {
                // Reduce after every term to keep the running sum small; the
                // "+ mod) % mod" maps negative remainders back into [0, mod).
                C[i][j] = ((C[i][j] + aik * B[k][j]) % mod + mod) % mod;
            }
        }
    }
    return C;
}
// Raises the square matrix A to the n-th power modulo `mod` by binary
// exponentiation (O(log n) calls to mul). For n <= 0 the identity matrix is
// returned, with 1s on the diagonal left unreduced.
mat pow(mat A, LL n)
{
    const size_t dim = A.size();
    mat result(dim, vec(dim, 0));
    for (size_t d = 0; d < dim; ++d)
        result[d][d] = 1;
    // Each round consumes the lowest bit of n; A holds the matching power of
    // the original matrix (A^1, A^2, A^4, ...).
    for (; n > 0; n >>= 1)
    {
        if (n & 1)
            result = mul(result, A);
        A = mul(A, A);
    }
    return result;
}
int main()
{
#ifdef LOCAL
FIN;
#endif // LOCAL
LL a, b, n;
while (cin >> a >> b >> n >> mod)
{
mat A(2, vec(2));
A[0][0] = 2 * a; A[0][1] = b - a * a;
A[1][0] = 1; A[1][1] = 0;
A = pow(A, n - 1);
LL ans = ((2 * a * A[0][0] + 2 * A[0][1]) % mod + mod )% mod;
printf("%d\n", ans);
}
return 0;
}