测试地址:The Sum of the k-th Powers
题目大意: 求
∑
i
=
1
n
i
k
\sum_{i=1}^ni^k
∑i=1nik对
1
0
9
+
7
10^9+7
109+7取模的值。
做法: 本题需要用到拉格朗日插值。
容易看出(或者用数学归纳法简单证明),答案
f
(
n
)
f(n)
f(n)是一个关于
n
n
n的最高次为
k
+
1
k+1
k+1的多项式。问题是,怎么得到这个多项式呢?这时候就要使用拉格朗日插值法。
拉格朗日插值法是一个可以由
k
+
1
k+1
k+1个二维平面上的点
(
x
i
,
y
i
)
(x_i,y_i)
(xi,yi),构造出一个正好穿过这些点的
k
k
k次函数
f
(
x
)
f(x)
f(x)的算法。有:
f
(
x
)
=
∑
i
=
1
k
+
1
y
i
⋅
l
i
(
x
)
f(x)=\sum_{i=1}^{k+1}y_i\cdot l_i(x)
f(x)=∑i=1k+1yi⋅li(x)
其中
l
i
(
x
)
l_i(x)
li(x)称为插值基函数,其表达式为
∏
i
≠
j
x
−
x
j
x
i
−
x
j
\prod_{i\ne j}\frac{x-x_j}{x_i-x_j}
∏i̸=jxi−xjx−xj。
为什么有这个式子呢?首先看存在性,注意到
l
i
(
x
)
l_i(x)
li(x)当且仅当
x
=
x
i
x=x_i
x=xi时有取值
1
1
1,否则当
x
=
x
j
(
i
≠
j
)
x=x_j(i\ne j)
x=xj(i̸=j)时取值为
0
0
0,那么显然上面的函数可以穿过对应的点。至于唯一性好像要用一些玄学的东西证,你只需要知道
k
+
1
k+1
k+1个点一定能确定一个
k
k
k次函数就行了…
回到这一题,我们显然可以用
(
i
,
∑
j
=
1
i
j
k
)
(
0
≤
i
≤
k
+
1
)
(i,\sum_{j=1}^ij^k)(0\le i\le k+1)
(i,∑j=1ijk)(0≤i≤k+1)这些点来使用拉格朗日插值法,算出这些点显然是
O
(
k
log
k
)
O(k\log k)
O(klogk)的(或者你可以用线性筛优化到
O
(
k
)
O(k)
O(k)…)。接下来我们把
n
n
n代入上式:
f
(
n
)
=
∑
i
=
0
k
−
1
y
i
⋅
∏
i
≠
j
n
−
j
i
−
j
f(n)=\sum_{i=0}^{k-1}y_i\cdot \prod_{i\ne j}\frac{n-j}{i-j}
f(n)=∑i=0k−1yi⋅∏i̸=ji−jn−j
后面的积式中,分子和分母可以分别预处理,分子只要处理出前缀积和后缀积即可(不能用
1
n
−
i
⋅
∏
(
n
−
j
)
\frac{1}{n-i}\cdot \prod (n-j)
n−i1⋅∏(n−j)这个式子,因为
n
−
i
n-i
n−i有可能为
0
0
0),而对于分母,我们发现分母可以拆成两个类似于下面这样的部分:
∏
i
=
1
x
1
i
\prod_{i=1}^x\frac{1}{i}
∏i=1xi1和
∏
i
=
1
x
1
(
−
i
)
\prod_{i=1}^x\frac{1}{(-i)}
∏i=1x(−i)1,分别预处理出来即可,时间复杂度为
O
(
k
log
k
)
O(k\log k)
O(klogk)(或者用线性求逆元优化成
O
(
k
)
O(k)
O(k))。
所以我们就解决了此题,时间复杂度为
O
(
k
log
k
)
O(k\log k)
O(klogk)(可以优化到
O
(
k
)
O(k)
O(k))。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
ll n,k,invp[1000010],invn[1000010],inv[1000010];
ll pre[1000010],suf[1000010];
ll power(ll a,ll b)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=s*ss%mod;
ss=ss*ss%mod;
b>>=1;
}
return s;
}
int main()
{
scanf("%lld%lld",&n,&k);
invp[0]=invn[0]=1;
pre[0]=n;
for(ll i=1;i<=k+1;i++)
{
invp[i]=invp[i-1]*power(i,mod-2)%mod;
invn[i]=invn[i-1]*power(mod-i,mod-2)%mod;
pre[i]=pre[i-1]*(n-i+mod)%mod;
}
suf[k+1]=(n-k-1+mod)%mod;
for(int i=k;i>=0;i--)
suf[i]=suf[i+1]*(n-i+mod)%mod;
ll ans=0,now=0;
for(int i=0;i<=k+1;i++)
{
if (k>0||i>0) now=(now+power(i,k))%mod;
ll tot=1;
if (i>0) tot=tot*pre[i-1]%mod;
if (i<k+1) tot=tot*suf[i+1]%mod;
ll nowtot=tot*invp[i]%mod*invn[k-i+1]%mod;
ans=(ans+now*nowtot)%mod;
}
printf("%lld",ans);
return 0;
}