心得
其实挑战上提了一句这个,只滚系数是O(k²logn)的,但是一直不会实现
今天终于把代码啃下来了,以后带着板子抄就可以了……
适用于dp递推式,O(k³logn)矩阵快速幂超时的场合,
可O(k²)的BM求线性递推式或O(k²logn)的系数矩阵快速幂
思路来源
https://ac.nowcoder.com/acm/contest/view-submission?submissionId=40893084杜教AC代码
https://wenku.baidu.com/view/bac23be1c8d376eeafaa3111.html(叉姐的论文《线性递推关系与矩阵乘法》)
知识整理
如果,
那么,倒着往回推系数,
迟早能推成全是表示的项,再用对应系数一乘一求和就得到一个an的答案
所以思路就是,若要将an用表示,
就要分治的将用
表示,
用
表示,
然后和
对应系数一乘乘出
的系数,
再把这些项每一个用
这样的式子,
倒着从下放系数,直到所有的系数都是用
表示的为止
由于,第一次构造的是的k阶表示,后续不断自乘才得到最高的k阶表示
所以,代码中w和x的作用就是交换高低位,w是不大于n的最大的2的次幂
而当w当前最高位为1时,x当前最低位为1,b是判断现在n中是否有w的这一位
如果有,执行形如,实际操作时把v[i]*v[j]的结果直接加到u[i+j+1]上,
其效果等同于先加到u[i+j]上然后向高一位乘v[1],因为这里v[1]==1恒成立,就省略了
板子整理
以2019牛客暑期多校训练营(第二场)B题.Eddy Walker 2为例
k<=1050,n<=1e18,只能用O(k²)的BM或O(k²logn)的系数矩阵快速幂
#include <bits/stdc++.h>
using namespace std;
#define rep(i,n) for(int i=1;i<=n;++i)
#define mp make_pair
#define pb push_back
#define x0 gtmsub
#define y0 gtmshb
#define x1 gtmjtjl
#define y1 gtmsf
typedef long long ll;
//M为递推项系数个数 c为系数数组
//最终递推式为a[m]=c[0]a[0]+c[1]a[1]+...+c[m-1]a[m-1]
const int M=1050,P=1000000007;
//求快速幂 只写一个系数即求逆元
ll pw(ll x,ll y=P-2){
ll s=1;
for(;y;y>>=1,x=1ll*x*x%P)
if(y&1)s=1ll*s*x%P;
return s;
}
ll i,w,x,b,j,t,a[M],c[M],v[M],u[M<<1],ans;
//推a1...ak每项最终系数的矩阵快速幂 复杂度O(k^2 logn)
//求a^n的k个a^1 a^k的系数 即求a^(n/2)的k个系数 然后乘在一起变成a^1 到a^(2k)的系数
//然后暴力把a^(k+1)到a^(2k)的系数再倒序下放到a^1 到 a^k上
ll sol(ll n,ll m) {//求a[n] a[n]来自前m项递推式
//scanf("%d%d",&n,&m);
n+=m-1;//整体右移m-1项 便于0-(m-1)向负的找系数 应为0
for(i=m-1;~i;i--)c[i]=pw(m);//c[m-1]到c[0] 每个1/m 相当于x^1
for(i=0;i<m-1;i++)a[i]=0;a[m-1]=1;//a[0]-a[m-2]为0 a[m-1]实际为a[0]=0
for(i=0;i<m;i++)v[i]=1;//相当于x^0 (v^2)*v推一次即得c
for(w=!!n,i=n;i>1;i>>=1)w<<=1;//n=0时w=0 n!=0时w>1 w为不大于n的最大2的次幂
for(x=0;w;copy(u,u+m,v),w>>=1,x<<=1){//copy把[u,u+m]复制给v
fill_n(u,m<<1,0),b=!!(n&w),x|=b;//fill_n 把u.begin()的连续m<<1个位置 都覆盖成0
//如果n&w==0 b=0;n&w==w b=1 用两个!把非空判成了1
//如果w最高位为1 则x最低位|=1
if (x<m)u[x]=1;
else {
//如果b==1 说明应u[i+j+1]+=v[i]*v[j] 这里v[1]==1 故省略
//b==1时 直接向高1位加系数 就起到了先向u[i+j]加系数 再整体移到了u[i+j+1]这一位的作用
//类似快速幂的x^(2n+1)=x^(n)*x^(n)*x^1
//每次无论遇到0还是1 都有一个自乘左移的过程 所以要最高位第一次考虑 以此类推
for(i=0;i<m;i++)for(j=0,t=i+b;j<m;j++,t++)u[t]=((ll)v[i]*v[j]+u[t])%P;
for(i=(m<<1)-1;i>=m;i--)for(j=0,t=i-m;j<m;j++,t++)u[t]=((ll)c[j]*u[i]+u[t])%P;//从2k-1到k 下放系数到k-1到0
//注意u[2m-1]=c[0]*u[m-1]+...+c[m-1]*u[2m-2] 每个u[2m-1]的值v可以给u[m-1]下放v*c[0]
}
}
ans=0;
for(i=0;i<m;i++)ans=(1ll*v[i]*a[i]+ans)%P;//推成前m阶系数 第i个系数是v[i] 值为a[i]
return ans;
}
int T,k;
ll n;
int main()
{
for(scanf("%d",&T);T--;)
{
scanf("%d%lld",&k,&n);
if(n==-1)printf("%lld\n",2ll*pw(k+1)%P);
else printf("%lld\n",sol(n,k)%P);//
}
return 0;
}