这种奇葩题肯定没有链接呀
题目描述
出题人记:
题解
设
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示第一次达到
x
=
i
,
y
=
j
x=i,y=j
x=i,y=j 状态的期望步数,那么很容易列出式子:
d
p
[
0
]
[
0
]
=
1
d
p
[
i
]
[
j
]
=
p
∗
d
p
[
i
−
1
m
o
d
n
]
[
j
]
+
q
∗
d
p
[
i
]
[
j
−
1
m
o
d
m
]
+
1
dp[0][0]=1\\ dp[i][j]=p*dp[i-1\bmod n][j]+q*dp[i][j-1\bmod m]+1
dp[0][0]=1dp[i][j]=p∗dp[i−1modn][j]+q∗dp[i][j−1modm]+1
显然转移是有环的,只能高斯消元来求。直接高斯消元的话未知数太多,我们最多只能定
m
m
m 个主元,然后剩下的递推求。
如果我们定的是 x = 0 x=0 x=0 的 m m m 个主元,那么由于剩下部分的递推式处处相同,显然 ( 0 , j ) (0,j) (0,j) 对 ( x , y ) (x,y) (x,y) 处贡献的系数和 ( 0 , j + k m o d m ) (0,j+k\bmod m) (0,j+kmodm) 对 ( x , y + k m o d m ) (x,y+k\bmod m) (x,y+kmodm) 处贡献的系数是一样的,我们只关心主变元相对位置的系数。
系数从一行推到下一行时,下一行内部的转移也是有一个环的。我们把环手解出来,那么就得到了一行的 m m m 个递推系数。
我们发现系数推到下一行时就是当前系数多项式与推第一行时的系数多项式循环卷积的结果,并且这是个线性变换有结合律,所以我们用NTT做循环卷积再套上快速幂即可。
系数和常数可以分开算,总复杂度 O ( m 3 + t m log m log n ) O(m^3+tm\log m\log n) O(m3+tmlogmlogn)。
代码
#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define lll __int128
#define uns unsigned
#define fi first
#define se second
#define IF (it->fi)
#define IS (it->se)
#define END putchar('\n')
#define lowbit(x) ((x)&-(x))
#define inline jzmyyds
using namespace std;
const int MAXN=2005;
const ll INF=1e18;
ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt>0)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
const ll MOD=998244353;
ll ksm(ll a,ll b,ll mo){
ll res=1;
for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
return res;
}
#define g 3ll
int rev[1919],omg[1919];
int NTT(ll*a,int N,int inv){
int n=N,bit=1;
while((1<<bit)<n)bit++;
n=(1<<bit);
for(int i=0;i<n;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
if(i<rev[i])swap(a[i],a[rev[i]]);
}ll x,y,tmp;omg[0]=1;
for(int m=1,mi=(MOD+1)>>1;m<n;m<<=1,mi>>=1){
tmp=ksm(g,inv>0?mi:MOD-1-mi,MOD);
for(int i=1;i<m;i++)omg[i]=omg[i-1]*tmp%MOD;
for(int i=0,om=0;i<n;i+=(m<<1),om=0)
for(int j=i;j<i+m;j++,om++)
x=a[j],y=a[j+m]*omg[om]%MOD,a[j]=(x+y)%MOD,a[j+m]=(x-y+MOD)%MOD;
}if(inv<0)for(int i=0,iv=ksm(n,MOD-2,MOD);i<n;i++)(a[i]*=iv)%=MOD;
return n;
}
#undef g
ll c[405][405],X[405];
#define x X
void Gauss(int n){
for(int row=0,col=0;row<n&&col<n;row++,col++){
if(!c[row][col]){
for(int i=row+1;i<n;i++)if(c[i][col]){
for(int j=col;j<=n;j++)swap(c[i][j],c[row][j]);
break;
}
if(!c[row][col]){row--;continue;}
}
const ll g=MOD-ksm(c[row][col],MOD-2,MOD);
for(int i=row+1;i<n;i++)if(c[i][col]){
const ll cg=c[i][col]*g%MOD;
for(int j=col;j<=n;j++)(c[i][j]+=c[row][j]*cg)%=MOD;
}
}
for(int i=n-1;i>=0;i--){
for(int j=i+1;j<n;j++)(c[i][n]+=MOD-c[i][j]*x[j]%MOD)%=MOD;
x[i]=c[i][n]*ksm(c[i][i],MOD-2,MOD)%MOD;
}
}
#undef x
int n,m,k=1;
ll p,q,iv,f[35][1919],g[1919],h[1919],iq,cq;
using pll=pair<ll,ll>;
ll getC(int x){
pll a=pll(0,0),b=pll(p*cq%MOD,cq);
for(;x;x>>=1,b.se=(b.se*b.fi+b.se)%MOD,(b.fi*=b.fi)%=MOD)
if(x&1)(a.fi*=b.fi)%=MOD,a.se=(a.se*b.fi+b.se)%MOD;
return a.se;
}
int main()
{
freopen("huawei.in","r",stdin);
freopen("huawei.out","w",stdout);
n=read(),m=read(),p=read(),q=read(),iv=ksm(p+q,MOD-2,MOD);
(p*=iv)%=MOD,(q*=iv)%=MOD,cq=ksm(MOD+1-q,MOD-2,MOD);
iq=ksm(MOD+1-ksm(q,m,MOD),MOD-2,MOD);
for(int i=0;i<m;i++)f[0][i]=p*ksm(q,i,MOD)%MOD*iq%MOD;
while(k<(m<<1))k<<=1;
for(int i=1;i<=30;i++){
for(int j=0;j<m;j++)g[j]=f[i-1][j];
for(int j=m;j<k;j++)g[j]=0;
NTT(g,k,1);
for(int j=0;j<k;j++)(g[j]*=g[j])%=MOD;
NTT(g,k,-1);
for(int j=0;j<k;j++)(f[i][j%m]+=g[j])%=MOD;
}
h[0]=1;
for(int i=0;i<=30;i++)if((n-1)&(1<<i)){
for(int j=0;j<m;j++)g[j]=f[i][j];
for(int j=m;j<k;j++)g[j]=0;
NTT(g,k,1),NTT(h,k,1);
for(int j=0;j<k;j++)(g[j]*=h[j])%=MOD,h[j]=0;
NTT(g,k,-1);
for(int j=0;j<k;j++)(h[j%m]+=g[j])%=MOD;
}
ll nc=getC(n-1);
c[0][0]=1;
for(int i=1;i<m;i++){
c[i][i]=1,c[i][m]=(nc*p+1)%MOD,c[i][i-1]=MOD-q;
for(int j=0;j<m;j++)
(c[i][j]+=MOD-h[(i-j+m)%m]*p%MOD)%=MOD;
}
Gauss(m);
for(int Q=read();Q--;){
int x=read(),y=read();
for(int i=0;i<m;i++)h[i]=X[i];
for(int i=m;i<k;i++)h[i]=0;
for(int i=0;i<=30;i++)if(x&(1<<i)){
for(int j=0;j<m;j++)g[j]=f[i][j];
for(int j=m;j<k;j++)g[j]=0;
NTT(g,k,1),NTT(h,k,1);
for(int j=0;j<k;j++)(g[j]*=h[j])%=MOD,h[j]=0;
NTT(g,k,-1);
for(int j=0;j<k;j++)(h[j%m]+=g[j])%=MOD;
}
print((h[y]+getC(x))%MOD);
}
return 0;
}