题目描述
热情好客的小猴请森林中的朋友们吃饭,他的朋友被编号为 1∼N,每个到来的朋友都会带给他一些礼物:香蕉。其中,第一个朋友会带给他1个香蕉,之后,每一个朋友到来以后,都会带给他之前所有人带来的礼物个数再加他的编号的K次方那么多个。所以,假设 K=2,前几位朋友带来的礼物个数分别是:
1,5,15,37,83,…
假设 K=3,前几位朋友带来的礼物个数分别是:
1,9,37,111,…
现在,小猴好奇自己到底能收到第 N 个朋友多少礼物,因此拜托于你了。
已知 N,K,请输出第 N 个朋友送的礼物个数 mod 1000000007 。
1,5,15,37,83,…
假设 K=3,前几位朋友带来的礼物个数分别是:
1,9,37,111,…
现在,小猴好奇自己到底能收到第 N 个朋友多少礼物,因此拜托于你了。
已知 N,K,请输出第 N 个朋友送的礼物个数 mod 1000000007 。
输入
第一行,两个整数 N,K。
输出
一个整数,表示第N个朋友送的礼物个数 mod 1000000007。
样例输入
4 2
样例输出
37
提示
100% 的数据:N≤1018,K≤10。
来源
分析:这几天做了好几个矩阵快速幂,略有心得;1、数学必须好,不然推不出系数矩阵。
2、大数的运算取余时,能取余就取余而且尽早取余,免得超过范围long long了存不下;
输出的地方也要取余,说不定就卡你某个示例,在没有把握的情况下,都写上比较好。
3、取余时要避免减法带来的负数,通常(a+MOD)%MOD;
现在遇到的情况就这些,以后再补充。
4、很多题目的类型都是if(n<k) ... else ...进入快速幂函数,所以最后输出的时候,必须取余,
虽然在函数里每一步操作都取余,但是也有不进入函数的情况,进入cal函数也有不进入mul计算的情况。
代码如下:
/*
An=Sn-1+n^k;
Sn=An+Sn-1=2Sn-1+n^k
=2Sn-1+C(k,0)*(n-1)^k+C(k,1)*(n-1)^(k-1)+ ..... C(k,k)*(n-1)^0;
题目要求的是An,但是使用矩阵快速幂必须依赖Sn和Sn-1的关系
构造矩阵:
|2 C(x,0) C(x,1) C(x,2) ..... C(x,x)| |Sn-1 | | S(n) |
|0 C(x,0) ...... C(x,x) | |(n-1)^k | | n^k |
|0 0 C(x-1,0)...... C(x-1,x-1) | *|... |=| ... |
|. . . | |(n-1)^2 | | n^2 |
|... | |(n-1)^1 | | n^1 |
|0 0 0 0 0 0 C(0,0) | |(n-1)^0 | | n^0 |
*/
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
#define MAXN 15
#define MOD 1000000007
long long c[15][15],K;
long long N;
struct Node {
long long M[20];
}res;
struct Node_a {
long long M[20][20];
}ori;
void calc_comb() //组合数打表
{
c[0][0] = c[1][0] = c[1][1]=1;
for(int i=2;i<MAXN;i++)
{
c[i][0] = c[i][i] = 1;
for(int j=1;j<i;j++)
c[i][j] = (c[i-1][j]+c[i-1][j-1])%MOD;
}
}
Node mul(Node a,Node_a b) //系数矩阵自乘
{
Node tem;
memset(tem.M,0,sizeof(tem.M));
for(int i=0;i<K+2;i++)
for(int j=0;j<K+2;j++)
tem.M[i]=(tem.M[i]+(ori.M[i][j]*res.M[j])%MOD)%MOD;
return tem;
}
Node_a mul1(Node_a a ,Node_a b)
{
Node_a tem;
memset(tem.M,0,sizeof(tem.M));
for(int i=0;i<K+2;i++)
for(int j=0;j<K+2;j++)
for(int k=0;k<K+2;k++)
tem.M[i][j]=(tem.M[i][j]+(a.M[i][k]*b.M[k][j])%MOD)%MOD;//防止越界
return tem;
}
void calc()
{
memset(ori.M,0,sizeof(ori.M)); ori.M[0][0]=2;
for(int i=1;i<=K+1;i++) //初始化系数矩阵第一行
ori.M[0][i]=c[K][i-1];
for(int i=1;i<K+2;i++) //初始化第二行,往后
for(int j=i;j<K+2;j++)
ori.M[i][j]=c[K+1-i][j-i];
res.M[0]=1;
for(int i=1;i<K+2;i++)
res.M[i]=1;
N-=2;
while(N)
{
if(N&1)
res=mul(res,ori);
N>>=1;
ori=mul1(ori,ori);
}
}
long long ppow(long long a,long long b)
{
a=a%MOD; //提前取余防止越界 ,这里取余后用pow也行因为K<=10;
long long ans=1;
while(b)
{
if(b&1)
ans=(a*ans)%MOD;
b>>=1;
a=(a*a)%MOD;
}
return ans%MOD;
}
int main()
{
calc_comb();
while(~scanf("%lld%lld",&N,&K))
{
long long N1=N;
if(N==1)
printf("1\n");
else
{
calc();
printf("%lld\n",(res.M[0]+ppow(N1,K))%MOD); //避免一切可能超的地方
}
}
return 0;
}