看了题解的推导发现其实并不复杂,但是如果你想要用多项式或者组合数求解的话,就GG了
其实如果把式子列出来的话,不需要怎么推导就能算出来,关键是要想到这个巧妙的式子。
设\(b_i=a_{i+1}-a_{i}(1\leq i\leq k-1)\)
答案就是
\[\sum_{b_1=1}^{m}\sum_{b_2=1}^{m}...\sum_{b_{k-1}=1}^{m}(n-\sum_{i=1}^{k-1}b_i)\]
\[nm^{k-1}-\sum_{i=1}^{k-1}\sum_{b_1=1}^{m}\sum_{b_2=1}^{m}...\sum_{b_{k-1}=1}^{m}b_i\]
\[nm^{k-1}-(k-1)m^{k-2}\sum_{i=1}^{m}i\]
\[nm^{k-1}-(k-1)m^{k-2}\frac{m(m+1)}{2}\]
然后直接算就可以了
这题的关键在于\((k-1)m<n\),它保证了\((n-\sum_{i=1}^{k-1}b_i)\)非负,这样就只需要对每一个序列\(\{b_i\}\)简单地累加贡献就可以了。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll n,k,m,P;
ll p2(ll x){return x*x%P;}
ll pw(ll x,ll p)
{
return p?p2(pw(x,p/2))*(p&1?x:1)%P:1;
}
int main()
{
scanf("%lld%lld%lld%lld",&n,&k,&m,&P);
n%=P;
if(k==1)return printf("%lld\n",n),0;
ll a=n*pw(m,k-1)%P;
ll b=m*(m+1)/2%P*(k-1)%P*pw(m,k-2)%P;
ll ans=(a-b+P)%P;
printf("%lld\n",ans);
return 0;
}