链接
题解
这题好像只有我跑了 3 m s 3ms 3ms?他们咋都用了那——么长时间呢,我看有些是矩阵快速幂,有些是 B M BM BM,好像没人算法和我一样…
我的算法其实也非常直接,首先写出斐波那契数列通项:
f n = 1 5 ( 1 + 5 2 ) n − 1 5 ( 1 − 5 2 ) n f_n = \frac{1}{\sqrt 5}(\frac{1+\sqrt 5}{2})^n - \frac{1}{\sqrt 5}(\frac{1-\sqrt 5}{2})^n fn=51(21+5)n−51(21−5)n
为了方便后面叙述,我把上述公式写成 f n = a 1 b 1 n + a 2 b 2 n f_n = a_1 b_1^n +a_2b_2^n fn=a1b1n+a2b2n
然后带入他让求的那个东西:
a n s = ∑ i = 1 n i k f i = a 1 ∑ i = 1 n i k b 1 i + a 2 ∑ i = 1 m i k b 2 i ans = \sum_{i=1}^n i^k f_i \\ = a_1\sum_{i=1}^n i^k b_1^i + a_2\sum_{i=1}^mi^kb_2^i ans=i=1∑nikfi=a1i=1∑nikb1i+a2i=1∑mikb2i
现在问题就成了怎么求 ∑ i = 1 n i k b i \sum_{i=1}^n i^k b^i ∑i=1nikbi
令 f ( x ) = ∑ i = 1 n x i f(x) = \sum_{i=1}^n x^i f(x)=∑i=1nxi
下面开始推公式:
f ′ ( x ) = ∑ i = 1 n i x i − 1 x f ′ ( x ) = ∑ i = 1 n i x i ( x f ′ ( x ) ) ′ = x f ′ ′ ( x ) + f ′ ( x ) = ∑ i = 1 n i 2 x i − 1 x 2 f ′ ′ ( x ) + x f ′ ( x ) = ∑ i = 1 n i 2 x i . . . f'(x) = \sum_{i=1}^ni x^{i-1}\\ xf'(x) = \sum_{i=1}^ni x^i\\ (xf'(x))' = xf''(x) + f'(x) = \sum_{i=1}^ni^2 x^{i-1} \\ x^2f''(x) + xf'(x) = \sum_{i=1}^ni^2 x^i \\ ... f′(x)=i=1∑nixi−1xf′(x)=i=1∑nixi(xf′(x))′=xf′′(x)+f′(x)=i=1∑ni2xi−1x2f′′(x)+xf′(x)=i=1∑ni2xi...
然后就这样继续推下去,就可以得到 ∑ i = 1 n i k x i \sum_{i=1}^ni^kx^i ∑i=1nikxi对应等式左边的样子,这个递推的过程直接交给计算机,可以得到一个 O ( k 2 ) O(k^2) O(k2)的暴力
然后接下来就是怎么求 f ( x ) , f ′ ( x ) , f ′ ′ ( x ) , f ′ ′ ′ ( x ) , f ( 4 ) ( x ) , . . . , f ( k ) ( x ) f(x),f'(x),f''(x),f'''(x),f^{(4)}(x),...,f^{(k)}(x) f(x),f′(x),f′′(x),f′′′(x),f(4)(x),...,f(k)(x)在 x = b x=b x=b处的值
这个还是有点脑洞在里面:
f ( x ) = x − x n + 1 1 − x ( 1 − x ) f ( x ) = x − x n + 1 ( 1 − x ) f ′ ( x ) − f ( x ) = 1 − ( n + 1 ) x n ( 1 − x ) f ′ ′ ( x ) − 2 f ′ ( x ) = − ( n + 1 ) n x n − 1 . . . f(x) = \frac{x-x^{n+1}}{1-x} \\ (1-x)f(x) = x-x^{n+1} \\ (1-x)f'(x) - f(x) = 1 - (n+1)x^n \\ (1-x)f''(x) - 2f'(x) = -(n+1)nx^{n-1} \\ ... f(x)=1−xx−xn+1(1−x)f(x)=x−xn+1(1−x)f′(x)−f(x)=1−(n+1)xn(1−x)f′′(x)−2f′(x)=−(n+1)nxn−1...
所以会发现这个东西是可以递推的,递推的过程是 O ( k ) O(k) O(k)的
当我递推出了这些值之后,直接带入我前一步里面用计算机递推出的式子,就可以求出答案了
算法的总时间复杂度是 O ( k 2 ) O(k^2) O(k2)的
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 110
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define mod 998244353ll
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
struct EasyMath
{
ll prime[maxn], phi[maxn], mu[maxn];
bool mark[maxn];
ll fastpow(ll a, ll b, ll c)
{
ll t(a%c), ans(1ll);
for(;b;b>>=1,t=t*t%c)if(b&1)ans=ans*t%c;
return ans;
}
void exgcd(ll a, ll b, ll &x, ll &y)
{
if(!b){x=1,y=0;return;}
ll xx, yy;
exgcd(b,a%b,xx,yy);
x=yy, y=xx-a/b*yy;
}
ll inv(ll x, ll p) //p是素数
{return fastpow(x%p,p-2,p);}
ll inv2(ll a, ll p)
{
ll x, y;
exgcd(a,p,x,y);
return (x+p)%p;
}
void shai(ll N)
{
ll i, j;
for(i=2;i<=N;i++)mark[i]=false;
*prime=0;
phi[1]=mu[1]=1;
for(i=2;i<=N;i++)
{
if(!mark[i])prime[++*prime]=i, mu[i]=-1, phi[i]=i-1;
for(j=1;j<=*prime and i*prime[j]<=N;j++)
{
mark[i*prime[j]]=true;
if(i%prime[j]==0)
{
phi[i*prime[j]]=phi[i]*prime[j];
break;
}
mu[i*prime[j]]=-mu[i];
phi[i*prime[j]]=phi[i]*(prime[j]-1);
}
}
}
ll CRT(vector<ll> a, vector<ll> m) //要求模数两两互质
{
ll M=1, ans=0, n=a.size(), i;
for(i=0;i<n;i++)M*=m[i];
for(i=0;i<n;i++)(ans+=a[i]*(M/m[i])%M*inv2(M/m[i],m[i]))%=M;
return ans;
}
}em;
struct C
{
ll a, b;
C(ll a, ll b){this->a=a;this->b=b;}
C(){a=b=0;}
};
C operator+(C x, C y)
{
return C((x.a+y.a)%mod,(x.b+y.b)%mod);
}
C operator-(C x, C y)
{
return C((x.a-y.a)%mod,(x.b-y.b)%mod);
}
C operator*(C x, C y)
{
return C((x.a*y.a+5*x.b*y.b)%mod,(x.a*y.b+x.b*y.a)%mod);
}
C inv(C x)
{
ll t = em.inv(x.a*x.a-x.b*x.b*5,mod);
return C(x.a*t%mod,-x.b*t%mod);
}
C fastpow(C a, ll b)
{
C ans(1,0), t=a;
for(;b;b>>=1,t=t*t)if(b&1)ans=ans*t;
return ans;
}
ll n, k, f[maxn][maxn];
C dao[maxn];
C calc(C b)
{
cl(dao);
ll i, j;
dao[0]=(b-fastpow(b,n+1))*inv(C(1,0)-b);
rep(i,1,k)
{
if(i==1)dao[i] = C(1,0);
ll t=1;
rep(j,1,i)t=t*(n%mod-j+2)%mod;
if(i<=n+1)dao[i] = dao[i] - C(t,0)*fastpow(b,n+1-i);
dao[i] = dao[i] + C(i,0) * dao[i-1];
dao[i] = dao[i] * inv(C(1,0)-b);
}
cl(f);
f[0][0]=1;
rep(i,0,k)
{
rep(j,0,k)
{
(f[i+1][j+1]+=f[i][j])%=mod;
(f[i+1][j]+=j*f[i][j])%=mod;
}
}
C ret;
rep(i,0,k)
{
ret = ret + C(f[k][i],0) * fastpow(b,i) * dao[i];
}
return ret;
}
int main()
{
ll i;
cin >> n >> k;
C b1 = (C(1,0)+C(0,1))*inv(C(2,0)), b2 = (C(1,0)-C(0,1))*inv(C(2,0)), a1, a2, ans;
a1 = inv(C(0,1)), a2=inv(C(0,-1));
ans = a1*calc(b1) + a2*calc(b2);
cout << (ans.a+mod)%mod << endl;
return 0;
}