题目
https://ac.nowcoder.com/acm/contest/11260/A
样例
in | out |
---|---|
1 1 1 1 1 1 | 3 |
1 1 4 5 1 4 | 1267 |
8 9 3 8 9 3 | 811685801 |
1000000000 1000000000 1000000000 50 50 1000000000 | 821622942 |
1 1 1 50 50 999999999 | 482598534 |
分析
难点有两处:一是类欧几里得算法;二是利用拉格朗日插值求出下面这个多项式的系数:

$$F(n,p)=\sum_{i=0}^{n} i^{p}=\sum_{j=0}^{p+1} a_{p,j}\, n^{j}$$

即 $F(n,p)$ 是关于 $n$ 的 $p+1$ 次多项式,只要在 $n=0,1,\dots,p+1$ 这 $p+2$ 个点上求值,就能插值出全部系数 $a_{p,j}$。
坑点有两处:一是 Euclid 函数会以大量相同的参数被重复调用,于是用 map 做了记忆化;二是拉格朗日插值预处理的多项式次数至少要到 p+q+2 次(G 中会用到 F(n, p+q+1))。
代码
//https://ac.nowcoder.com/acm/contest/11260/A
#include <bits/stdc++.h>
using namespace std;
// Prime modulus for all arithmetic; modular inverses are taken as ksm(x, Ha-2)
// (Fermat), which relies on Ha being prime.
#define Ha (998244353ll)
// Binomial coefficients modulo Ha: C[i][j] = binom(i, j), filled by get_C().
long long C[3005][3005];
// Lagrange basis coefficients modulo Ha, filled by pre():
// Lag[n][i][j] is the coefficient of x^j in L_i(x) = prod_{k != i, 0<=k<=n} (x-k)/(i-k),
// the basis polynomial for node i over the nodes 0..n.
long long Lag[205][205][205];
// c[p][j]: coefficient of x^j in the polynomial equal to sum_{i=1}^{x} i^p (mod Ha),
// filled by pre().
long long c[205][205];
// Floating-point variants of the two tables above, presumably used while
// debugging the interpolation with debug(1).
//double Lag[105][105][105];
//double c[105][105];
// Interactive self-check helper; defined at the end of the file.
void debug(int);
// Binary exponentiation: returns x^k modulo Ha for k >= 0.
// The result may be negative when x is negative, because % keeps the sign of
// the dividend; callers normalize as needed.
long long ksm(long long x, long long k)
{
    long long result=1;
    long long base=x;
    while (k!=0) {
        if (k&1) result=result*base%Ha;
        base=base*base%Ha;
        k>>=1;
    }
    return result;
}
// Returns sum_{i=0}^{r} i^p modulo Ha, normalized to [0, Ha), with the
// convention 0^0 = 1 (so F(r, 0) = r + 1). Expects r >= 0 and pre() to have
// filled c[p] (the polynomial for sum_{i=1}^{x} i^p).
long long F(long long r, long long p)
{
    // Reduce like the polynomial path does, so both paths return a value in [0, Ha).
    if (p==0) return (r+1)%Ha;
    // For p >= 1 the i = 0 term is 0^p = 0, so the sum is c[p] evaluated at r.
    // Reducing r first keeps x*r inside 64 bits even for r above ~9.2e9.
    const long long rMod=r%Ha;
    long long ret=0;
    for (long long i=0,x=1; i<=p+1; i++,x=x*rMod%Ha) ret=(ret+c[p][i]*x %Ha)%Ha;
    // c[p][i] may be stored as a negative residue.
    return (ret+Ha)%Ha;
}
// Returns, modulo Ha and normalized to [0, Ha),
//     sum_{i=0}^{n} i^p * sum_{j=1}^{k*i} j^q          (0^0 = 1 for i^p).
// The inner sum is the polynomial c[q] evaluated at k*i, so the double sum
// collapses to sum_t c[q][t] * k^t * F(n, p+t). This needs c[q] from pre()
// and F for exponents up to p+q+1.
long long G(long long n, long long k, long long p, long long q)
{
    // Reduce k so that K*k cannot overflow even for very large k.
    const long long kMod=k%Ha;
    long long ret=0;
    long long K=1;   // k^t mod Ha
    for (long long t=0; t<=q+1; t++) {
        ret=(ret + c[q][t]*K %Ha *F(n,p+t) %Ha)%Ha;
        K=K*kMod%Ha;
    }
    // c[q][t] may be a negative residue; return a canonical value like F does.
    return (ret+Ha)%Ha;
}
int vis[101][101][101];   // unused
int ans[101][101][101];   // unused
int Moon=0;               // unused debug counter
#define LL long long
#define P pair<LL ,LL >
// Memo for Euclid: f[p][q] maps the exact arguments (a, b, c, n) to the result.
// Keying on the full tuple, rather than on a pair of linear hashes that distinct
// tuples can share, makes a wrong cache hit impossible.
// p and q index this array, so both must stay <= 100.
std::map<std::tuple<LL,LL,LL,LL>,int> f[101][101];
const int mod=1e9+7;      // no longer used by the memo key; kept for compatibility

// Euclid(a, b, c, n, p, q) returns, modulo Ha and normalized to [0, Ha),
//     sum_{x=0}^{n} x^p * sum_{y=1}^{floor((a*x+b)/c)} y^q        (0^0 = 1)
// using the similar-Euclid reduction below. Presumably a, b >= 0, c >= 1 and
// n >= 0 (TODO confirm against the problem limits); every recursive call keeps
// its arguments in that range.
int Euclid(long long a, long long b, long long c, long long n, long long p, long long q)
{
    const auto key=std::make_tuple(a,b,c,n);
    auto &memo=f[p][q];   // the map object itself is stable across recursive inserts
    const auto it=memo.find(key);
    if (it!=memo.end()) return it->second;
    // Largest inner bound floor((a*n+b)/c) is 0, so every inner sum is empty.
    if (a*n+b < c) return 0;

    long long ret1=0,ret2=0;
    if (b>=c) {
        // floor((a*x+b)/c) = k + floor((a*x + b%c)/c) with k = b/c.
        // The first k inner terms contribute F(n,p) * sum_{y=1}^{k} y^q
        // (F(k,q) also counts y = 0, which is 1 only when q == 0); the rest are
        // shifted by k and expanded with (y+k)^q = sum_t C(q,t) k^(q-t) y^t.
        const long long k=b/c;
        ret1= F(n,p)*(F(k,q)-(q==0)) %Ha;
        for (long long t=0; t<=q; t++) {
            ret2+= C[q][t]*ksm(k,q-t) %Ha *Euclid(a,b%c,c,n,p,t) %Ha;
            ret2%=Ha;
        }
        return memo[key]=( (ret1+ret2) %Ha+Ha)%Ha;
    }
    if (a>=c) {
        // floor((a*x+b)/c) = k*x + floor((a%c*x + b)/c) with k = a/c.
        // The first k*x inner terms are summed in closed form by G; for the rest,
        // (y+k*x)^q = sum_t C(q,t) (k*x)^(q-t) y^t moves x^(q-t) into the outer
        // power, so p+q is preserved.
        const long long k=a/c;
        ret1= G(n,k,p,q);
        for (long long t=0; t<=q; t++) {
            ret2+= C[q][t]*ksm(k,q-t) %Ha *Euclid(a%c,b,c,n,p+q-t,t) %Ha;
            ret2%=Ha;
        }
        return memo[key]=( (ret1+ret2) %Ha+Ha)%Ha;
    }
    // a < c and b < c: count by y instead of by x, using
    //     floor((a*x+b)/c) >= y  <=>  x > floor((c*y-b-1)/a),   m = floor((a*n+b)/c).
    // ret1 pairs every y in [1, m] with every x in [0, n]; ret2 removes the x
    // that are too small, after substituting y = y'+1 and expanding
    // (y'+1)^q = sum_t C(q,t) y'^t. The recursive call sums x from 1, so the
    // x = 0 term it omits (0^0 = 1 when p == 0) is cancelled by "- (p==0)".
    const long long m=(a*n+b)/c;
    ret1= (F(m,q)-(q==0))*(F(n,p)-(p==0)) %Ha;
    for (long long t=0; t<=q; t++) {
        ret2+= C[q][t]*Euclid(c,c-b-1,a,m-1,t,p) %Ha;
        ret2%=Ha;
    }
    return memo[key]=( (ret1-ret2) %Ha+Ha)%Ha;
}
// Reads one query "a b c p q n" from standard input and prints
// Euclid(a, b, c, n, p, q) followed by a newline.
void solve()
{
    long long a, b, c;
    long long p, q, n;
    cin >> a >> b >> c;
    cin >> p >> q >> n;
    const int answer=Euclid(a,b,c,n,p,q);
    cout << answer << '\n';
}
// Fills rows 0..MAX of Pascal's triangle modulo Ha: C[i][j] = binom(i, j)
// for 0 <= j <= i <= MAX. MAX must be below the table's 3005 rows.
void get_C(long long MAX)
{
    for (long long row=0; row<=MAX; ++row) {
        C[row][0]=1;
        for (long long col=1; col<=row; ++col)
            C[row][col]=(C[row-1][col]+C[row-1][col-1])%Ha;
    }
}
// Lagrange interpolation tables, all values stored as residues in (-Ha, Ha):
//   * Lag[n][i][j], n = 0..kMaxP+1: coefficient of x^j in the basis polynomial
//     L_i(x) = prod_{k != i, 0<=k<=n} (x-k)/(i-k) on the nodes 0..n, built
//     incrementally from the basis on the nodes 0..n-1.
//   * c[p][j], p = 0..kMaxP: coefficient of x^j in the degree-(p+1) polynomial
//     equal to sum_{i=1}^{x} i^p, obtained by interpolating its values at
//     x = 0..p+1.
void pre()
{
    const int kMaxP=105;   // highest power whose prefix-sum polynomial is built
    Lag[0][0][0]=1;
    Lag[1][0][0]=1;
    Lag[1][0][1]=-1;
    Lag[1][1][0]=0;
    Lag[1][1][1]=1;
    // c[p] interpolates through p+2 nodes (0..p+1), so Lag is needed up to
    // n = kMaxP+1; stopping at kMaxP left c[kMaxP] built from an all-zero table.
    for (int n=2; n<=kMaxP+1; n++) {
        // Existing nodes: L_i^{(n)}(x) = L_i^{(n-1)}(x) * (x-n) / (i-n).
        for (int i=0; i<n; i++) {
            const long long rev=ksm(i-n,Ha-2);
            Lag[n][i][0]= Lag[n-1][i][0]*(-n) %Ha *rev %Ha;
            for (int j=1; j<=n; j++)
                Lag[n][i][j]= (Lag[n-1][i][j-1] - Lag[n-1][i][j]*n %Ha) %Ha *rev %Ha;
        }
        // New node n: L_n^{(n)}(x) = L_{n-1}^{(n-1)}(x) * (x-(n-1)) / n.
        const long long rev=ksm(n,Ha-2);
        Lag[n][n][0]= Lag[n-1][n-1][0]*(1-n) %Ha *rev %Ha;
        for (int j=1; j<=n; j++)
            Lag[n][n][j]= (Lag[n-1][n-1][j-1] - Lag[n-1][n-1][j]*(n-1) %Ha) %Ha *rev %Ha;
    }
    c[0][0]=0,c[0][1]=1;                      // sum_{i=1}^{x} 1 = x
    c[1][0]=0,c[1][1]=c[1][2]=ksm(2,Ha-2);    // sum_{i=1}^{x} i = x(x+1)/2
    for (int p=2; p<=kMaxP; p++) {
        long long prefix=0;   // sum_{i=0}^{x} i^p at node x (0^p = 0 for p >= 1)
        for (int x=0; x<=p+1; x++) {
            prefix=(prefix+ksm(x,p))%Ha;
            for (int j=0; j<=p+1; j++)
                c[p][j]=(c[p][j] + prefix*Lag[p+1][x][j] %Ha) %Ha;
        }
    }
}
// Builds the lookup tables, then answers the single query read by solve().
int main()
{
    // Euclid only reads C[q][t] with q <= 100 (q also indexes f[101][101]), so
    // rows 0..105 of Pascal's triangle suffice. Building 2000 rows touched tens
    // of megabytes of C for nothing.
    get_C(105);
    pre();
    solve();
    return 0;
}
// Interactive self-check for the precomputed tables (main() does not call it).
//   opr == 2: repeatedly reads "p m", prints c[p][0..p+1], then prints
//             sum_{i=0}^{m} i^p computed directly next to the polynomial c[p]
//             evaluated at m. Stops at end of input or on malformed input.
//   opr == 1: reads n and n+1 values a[0..n], prints the coefficients of the
//             polynomial interpolating them via Lag[n], then re-evaluates it at
//             0..n. Only meaningful with the floating-point (double) tables;
//             with the modular tables the coefficients are residues.
void debug(int opr)
{
    while (opr==2) {
        int ppp,mmm;
        // Without this check the loop spins forever at EOF on stale values.
        if (scanf("%d%d",&ppp,&mmm)!=2) return;
        if (ppp<0 || ppp+1>=205) {   // c has 205 entries per row
            puts("p out of range");
            continue;
        }
        for (int i=0; i<=ppp+1; i++) printf("%lld ",c[ppp][i]);
        puts("\n");
        long long tmp1=0, tmp2=0;
        for (long long i=0; i<=mmm; i++) tmp1=(tmp1+ksm(i,ppp))%Ha;
        for (long long i=0,x=1; i<=ppp+1; x=x*mmm%Ha,i++) tmp2=(tmp2+c[ppp][i]*x%Ha)%Ha;
        printf("[0..%d] ^%d %lld %lld\n",mmm,ppp,tmp1, (tmp2+Ha)%Ha);
    }
    if (opr==1) {
        int nnn;
        // Lag has 205 slices; reject sizes that would index past it.
        if (scanf("%d",&nnn)!=1 || nnn<0 || nnn>=205) return;
        vector<double> a(nnn+1),ret(nnn+1);
        for (int i=0; i<=nnn; i++)
            if (scanf("%lf",&a[i])!=1) return;
        for (int i=0; i<=nnn; i++)
            for (int j=0; j<=nnn; j++) {
                ret[j]+=Lag[nnn][i][j]*a[i];
            }
        for (int j=0; j<=nnn; j++)
            printf("%f ",ret[j]);
        puts("");
        for (int i=0; i<=nnn; i++) {
            printf("[%2d] ",i);
            // Accumulate the power in double: a long long i^j overflows (UB)
            // long before j reaches nnn.
            double tmp=0, x=1;
            for (int j=0; j<=nnn; j++, x*=i) tmp+=ret[j]*x;
            printf("%f\n",tmp);
        }
        return;
    }
}