题目描述
一个数列 a n a_n an,已知 a 1 a_1 a1及 a 2 a_2 a2两项。
数列 a n a_n an满足递推式 a n = x × a n − 1 + y × a n − 2 ( n ≥ 3 ) a_n=x×a_{n−1}+y×a_{n−2}(n≥3) an=x×an−1+y×an−2(n≥3)
求 ∑ i = 1 n a i 2 . \sum_{i=1}^na_i^2. i=1∑nai2.
由于答案可能过大,对 1 0 9 + 7 10^9+7 109+7取模。
输入格式
第一行一个整数T,即数据组数。
下面T行,每行5个整数, n , a 1 , a 2 , x , y n,a_1,a_2,x,y n,a1,a2,x,y,含义如上。
输出格式
共T行,每行一个整数,即为每组数据的答案。
输入输出样例
输入 #1
3
5 1 1 1 1
4 3 4 3 2
461564597527246 987489553 321654648 164165256 315648984
输出 #1
40
4193
480929868
说明/提示
对于100%的数据, T = 30000 , 1 ≤ n ≤ 1 0 18 , 1 ≤ a 1 , a 2 , x , y ≤ 1 0 9 T=30000,1≤n≤10^{18},1≤a_1,a_2,x,y≤10^9 T=30000,1≤n≤1018,1≤a1,a2,x,y≤109
解释:尝试矩阵快速幂,
a n 2 = x 2 a n − 1 2 + y 2 a n − 2 2 + 2 x y a n − 1 a n − 2 a_n^2=x^2a_{n-1}^2+y^2a_{n-2}^2+2xya_{n-1}a_{n-2} an2=x2an−12+y2an−22+2xyan−1an−2
a n 2 a_n^2 an2都是线性关系,关键是怎么推导 a n − 1 a n − 2 a_{n-1}a_{n-2} an−1an−2,这里吗我把把它看成一个整体,则
a n a n − 1 = ( x a n − 1 + y a n − 2 ) a n − 2 = x a n − 1 2 + y a n − 1 a n − 2 a_{n}a_{n-1}=(xa_{n-1}+ya_{n-2})a_{n-2}=xa_{n-1}^2+ya_{n-1}a_{n-2} anan−1=(xan−1+yan−2)an−2=xan−12+yan−1an−2
那么我们可以定义矩阵
A = [ 1 1 0 0 0 x 2 y 2 2 x y 0 1 0 0 0 x 0 y ] A=\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & x^2 & y^2 & 2xy \\ 0 & 1 & 0 & 0\\ 0 & x & 0 & y \end{bmatrix} A=⎣⎢⎢⎡10001x21x0y20002xy0y⎦⎥⎥⎤
Y = [ s u m n − 1 a n 2 a n − 1 2 a n a n − 1 ] Y=\begin{bmatrix} sum_{n-1} \\ a_n^2 \\ a_{n-1}^2 \\ a_na_{n-1} \end{bmatrix} Y=⎣⎢⎢⎡sumn−1an2an−12anan−1⎦⎥⎥⎤
则
[ s u m n a n + 1 2 a n 2 a n + 1 a n ] = A n − 1 [ s u m 1 a 2 2 a 1 2 a 2 a 1 ] \begin{bmatrix} sum_{n} \\ a_{n+1}^2 \\ a_{n}^2 \\ a_{n+1}a_{n} \end{bmatrix}=A^{n-1}\begin{bmatrix} sum_{1} \\ a_{2}^2 \\ a_{1}^2 \\ a_{2}a_{1} \end{bmatrix} ⎣⎢⎢⎡sumnan+12an2an+1an⎦⎥⎥⎤=An−1⎣⎢⎢⎡sum1a22a12a2a1⎦⎥⎥⎤
直接矩阵快速幂
#include<cstdio>
using namespace std;
const int mod=1e9+7;
void mul(long long a[][5],long long b[][5]){
long long c[5][5]={0};
for(int i=1;i<=4;i++){
for(int j=1;j<=4;j++){
for(int k=1;k<=4;k++){
c[i][j]+=a[i][k]*b[k][j]%mod;
c[i][j]%=mod;
}
}
}
for(int i=1;i<=4;i++) for(int j=1;j<=4;j++) a[i][j]=c[i][j];
}
void pow(long long a[][5],long long b){
long long ret[5][5]={0};
ret[1][1]=ret[2][2]=ret[3][3]=ret[4][4]=1;
while(b){
if(b&1){
mul(ret,a);
}
mul(a,a);b>>=1;
}
for(int i=1;i<=4;i++) for(int j=1;j<=4;j++) a[i][j]=ret[i][j];
}
long long a[5][5]={0};
long long T=0;
long long n=0,a1=0,a2=0,x=0,y=0;
long long ret=0;
int main(){
scanf("%lld",&T);
while(T--){
scanf("%lld%lld%lld%lld%lld",&n,&a1,&a2,&x,&y);
x%=mod;y%=mod;a1%=mod;a2%=mod;
a[1][1]=1;a[1][2]=1;a[1][3]=0;a[1][4]=0;
a[2][1]=0;a[2][2]=x*x%mod;a[2][3]=y*y%mod;a[2][4]=2*x%mod*y%mod;
a[3][1]=0;a[3][2]=1;a[3][3]=0;a[3][4]=0;
a[4][1]=0;a[4][2]=x;a[4][3]=0;a[4][4]=y;
pow(a,n-1);
ret=(a[1][1]*a1%mod*a1%mod+a[1][2]*a2%mod*a2%mod)%mod+(a[1][3]*a1%mod*a1%mod+a[1][4]*a2%mod*a1%mod)%mod;
ret%=mod;
printf("%lld\n",ret);
}
return 0;
}