分解
UsedToBe (命题人)
基准时间限制:0.5 秒 空间限制:131072 KB 分值: 80
问(1+sqrt(2)) ^n 能否分解成 sqrt(m) +sqrt(m-1)的形式
如果可以 输出 m%1e9+7 否则 输出no
Input
一行,一个数n。(n<=10^18)
Output
一行,如果不存在m输出no,否则输出m%1e9+7
Input示例
2
Output示例
9
这道题需要证明(1+sqrt(2)) ^n 存在m恒有 sqrt(m) +sqrt(m-1) 的形式 ,具体证明在这里就不说了。
(1+sqrt(2))^n 可以化成 a + b*sqrt(2)的形式。
拓展出来有
a b n
1 1 1
3 2 2
7 5 3
17 12 4
…..
等, 这里我们就可以发现规律了,
从n=1开始,就有 an = an-1 +2* bn-1
bn = an-1 + bn-1
因为n的范围在10^18,O(n)的复杂度。
所以我们这里就可以构造矩阵,然后快速幂。
| 1 2 | * | an-1 | = | an |
| 1 1 | | bn-1 | | bn |
这样子就可以在O(log2n)的复杂度快速算出答案了。
得出an, bn还不够。
这样我们还要可以继续找规律,
你会发现n在奇偶的情况下有规律:
偶数 : m = a^2
奇数: m = 2*b^2
我们就可以做完这道题了。
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cctype>
#include<cmath>
#include<ctime>
#include<string>
#include<stack>
#include<deque>
#include<queue>
#include<list>
#include<set>
#include<map>
#include<cstdio>
#include<limits.h>
#define MOD 1000000007
#define fir first
#define sec second
#define fin freopen("/home/ostreambaba/文档/input.txt", "r", stdin)
#define fout freopen("/home/ostreambaba/文档/output.txt", "w", stdout)
#define mes(x, m) memset(x, m, sizeof(x))
#define Pii pair<int, int>
#define Pll pair<ll, ll>
#define INF 1e9+7
#define Pi 4.0*atan(1.0)
#define lowbit(x) (x&-x)
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
typedef long long ll;
typedef unsigned long long ull;
const double eps = 1e-7;
const int maxn = 101;
using namespace std;
//#define time
struct Matrix
{
ull mat[2][2];
void output(){
for(int i = 0; i < 2; ++i){
for(int j = 0; j < 2; ++j){
cout << mat[i][j] << " ";
}
cout << endl;
}
}
void clear(){
mes(mat, 0);
}
void init(){
mes(mat, 0);
for(int i = 0; i < 2; ++i){
mat[i][i] = 1;
}
}
Matrix operator *(const Matrix &b) const{
Matrix tmp;
tmp.clear();
for(int i = 0; i < 2; ++i){
for(int j = 0; j < 2; ++j){
for(int k = 0; k < 2; ++k){
tmp.mat[i][j] = (tmp.mat[i][j] + mat[i][k]*b.mat[k][j])%MOD;
}
}
}
return tmp;
}
};
Matrix fast_mod(ull n, Matrix &base){
Matrix res;
res.init();
while(n){
if(n&1){
res = res*base;
}
base = base*base;
n >>= 1;
}
return res;
}
int main()
{
ull n, m;
cin >> n;
Matrix base = {
1, 2,
1, 1
};
base = fast_mod(n-1, base);
Matrix p = {
1, 0,
1, 0
};
p = base*p;
if(n&1){
m = p.mat[1][0]*p.mat[1][0]*2%MOD;
}
else{
m = p.mat[0][0]*p.mat[0][0]%MOD;
}
cout << m << endl;
return 0;
}