传送门
题解:
其实和这道题差不多:https://blog.csdn.net/zxyoi_dreamer/article/details/89048235
只是看到一道类似的题,就重新写了一遍;这篇的代码要好看一点。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const
using std::cerr;
using std::cout;
// Prime modulus. Every helper below takes residues in [0, mod) and returns one.
const int mod = 1000000007;

// (a + b) mod p. Subtracting mod before adding keeps the intermediate small.
inline int add(int a,int b){
    const int s = a + (b - mod);
    return s < 0 ? s + mod : s;
}

// (a - b) mod p.
inline int dec(int a,int b){
    const int d = a - b;
    return d < 0 ? d + mod : d;
}

// (a * b) mod p, using a 64-bit product; skips the division when already reduced.
inline int mul(int a,int b){
    const long long prod = static_cast<long long>(a) * b;
    if (prod < mod) return static_cast<int>(prod);
    return static_cast<int>(prod % mod);
}

// In-place variants: a = a + b, a = a - b, a = a * b (all mod p).
inline void Inc(int &a,int b){ a = add(a, b); }
inline void Dec(int &a,int b){ a = dec(a, b); }
inline void Mul(int &a,int b){ a = mul(a, b); }

// Returns res * a^b mod p by binary exponentiation (b >= 0; res defaults to 1).
inline int power(int a,int b,int res=1){
    while (b != 0) {
        if (b & 1) res = mul(res, a);
        a = mul(a, a);
        b >>= 1;
    }
    return res;
}
// Capacity of every global table below (valid indices 0..N-1).
const int N = 500007;
// Input values read in main(): the answer is sum_{i=1}^{n} m^i * i^m mod p.
int n;
int m;
// Linear-sieve output of init(): p[1..pc] are the primes found so far,
// mark[k] is true once k has been identified as composite.
int p[N];
int pc;
bool mark[N];
// x[i] = i^m, f[i] = polynomial samples used by calc(),
// fac[i] = i!, ifac[i] = (i!)^{-1}; all taken mod p.
int x[N];
int f[N];
int fac[N];
int ifac[N];
inline void init(int lim=m+1){
x[1]=1;
for(int re i=2;i<=lim;++i){
if(!mark[i])p[++pc]=i,x[i]=power(i,m);
for(int re j=1;i*p[j]<=lim;++j){
mark[i*p[j]]=true;
x[i*p[j]]=mul(x[i],x[p[j]]);
if(i%p[j]==0)break;
}
}
fac[0]=fac[1]=1;
for(int re i=2;i<=lim;++i)fac[i]=mul(fac[i-1],i);
ifac[lim]=power(fac[lim],mod-2);
for(int re i=lim-1;~i;--i)ifac[i]=mul(ifac[i+1],i+1);
}
// Binomial coefficient C(n, m) mod p, read from the fac/ifac tables built by init()
// (so n must not exceed the limit init() was called with).
// Returns 0 when m lies outside [0, n] instead of indexing ifac[] out of bounds.
inline int C(int n,int m){
    if (m < 0 || m > n) return 0;
    return mul(fac[n], mul(ifac[m], ifac[n - m]));
}
// Scratch arrays for prefix/suffix products; main() fills them first, and
// calc() later overwrites them with its own products.
int pre[N];
int suf[N];
inline int calc(int n,int x){
if(x<=n)return f[x];
pre[0]=suf[n+1]=1;int ans=0;
for(int re i=1;i<=n;++i)pre[i]=mul(pre[i-1],x-i);
for(int re i=n;i>=1;--i)suf[i]=mul(suf[i+1],x-i);
for(int re i=1;i<=n;++i){
int coef=mul(mul(pre[i-1],suf[i+1]),mul(ifac[i-1],ifac[n-i]));
if((n^i)&1)coef=mod-coef;
Inc(ans,mul(coef,f[i]));
}
return ans;
}
signed main(){
#ifdef zxyoi
freopen("adventrue.in","r",stdin);
#endif
scanf("%d%d",&n,&m);if(m==1)printf("%lld",(ll)n*(n+1)/2%mod),exit(0);
init();
if(n<=m+1){
int ans=0,now=1;
for(int re i=1;i<=n;++i){
Mul(now,m);
Inc(ans,mul(now,x[i]));
}
cout<<ans<<"\n";
exit(0);
}
int inv=power(m,mod-2),a=1,b=0;
pre[0]=1,suf[0]=0;
for(int re i=1;i<=m+1;++i){
pre[i]=mul(pre[i-1],inv);
suf[i]=mul(add(suf[i-1],x[i]),inv);
int coef=C(m+1,i);if(i&1)coef=mod-coef;
Inc(a,mul(coef,pre[i]));
Inc(b,mul(coef,suf[i]));
}
int x=power(a,mod-2,dec(0,b));
for(int re i=0;i<=m+1;++i)f[i]=add(mul(pre[i],x),suf[i]);
int ans=calc(m+1,n);
ans=dec(mul(ans,power(m,n)),f[0]);
cout<<mul(ans,m)<<"\n";
return 0;
}