题目链接
这首先是一道高中数学题,我们要由数列的递推公式求出数列的通项公式。
由题目已知:
xi+1=axi+b(mod p)
x
i
+
1
=
a
x
i
+
b
(
m
o
d
p
)
这个可以用高中的待定系数法求出通项公式。我们设
xn+1+k=axn+b+k(mod p)
x
n
+
1
+
k
=
a
x
n
+
b
+
k
(
m
o
d
p
)
xn+1+k=a(xn+k)−(a−1)k+b(mod p)
x
n
+
1
+
k
=
a
(
x
n
+
k
)
−
(
a
−
1
)
k
+
b
(
m
o
d
p
)
我们要使 −(a−1)k+b=0 − ( a − 1 ) k + b = 0 来构造一个等比数列,那么我们解得 k=ba−1 k = b a − 1
xi+ba−1 x i + b a − 1 是等比数列,那么我们得到
xi+ba−1=(x1+ba−1)∗ai−1(mod p)
x
i
+
b
a
−
1
=
(
x
1
+
b
a
−
1
)
∗
a
i
−
1
(
m
o
d
p
)
根据题意,我们要求一个
xi=t
x
i
=
t
时的最小的
i
i
,我们发现上面的式子中除了外都是已知或者可以算出来的,那么我们移项得
ai−1=(xi+b∗inv(a−1))∗inv(x1+b∗inv(a−1))(mod p)
a
i
−
1
=
(
x
i
+
b
∗
i
n
v
(
a
−
1
)
)
∗
i
n
v
(
x
1
+
b
∗
i
n
v
(
a
−
1
)
)
(
m
o
d
p
)
,其中
inv
i
n
v
表示逆元(模意义下的除法转化为乘逆元)。
那么我们把右边的 (xi+b∗inv(a−1))∗inv(x1+b∗inv(a−1)) ( x i + b ∗ i n v ( a − 1 ) ) ∗ i n v ( x 1 + b ∗ i n v ( a − 1 ) ) 看作一个整体 D D ,又因为题目保证是质数,可以发现,这个问题变成了一个BSGS问题。
接下来是代码
#include <bits/stdc++.h>
using namespace std;
int T;
long long a,b,c,X1,t,ni1,ni2;
map <long long,long long> mp;
void exgcd(long long a,long long b,long long &x,long long &y)
{
if(!b)
{
x=1;
y=0;
}
else
{
exgcd(b,a%b,y,x);
y-=a/b*x;
}
}
long long ksm(long long x,long long y,long long mod)
{
long long res=1;
while(y)
{
if(y&1)
res=(res*x)%c;
x=(x*x)%c;
y>>=1;
}
return res;
}
long long gcd(long long x,long long y)
{
return y?gcd(y,x%y):x;
}
int main()
{
scanf("%d",&T);
while(T--)
{
mp.clear();
scanf("%lld%lld%lld%lld%lld",&c,&a,&b,&X1,&t);
long long x=0,y=0;
//坑人的特判
if(t==X1)
{
printf("1\n");
continue;
}
if(a==0)
{
if(t==b)
printf("2\n");
else
printf("-1\n");
continue;
}
if(a==1&&b==0)
{
printf("-1\n");
continue;
}
if(a==1)
{
exgcd(b,c,x,y);
ni1=(x%c+c)%c;
if(!ni1)
ni1+=c;
printf("%lld\n",(((((t-X1)%c)+c)%c)*ni1%c)%c+1);
continue;
}
exgcd(a-1,c,x,y);
ni1=(x%c+c)%c;//inv(a-1)
if(!ni1)
ni1+=c;
exgcd(ni1*b+X1,c,x,y);
ni2=(x%c+c)%c;//inv(x1+b*inv(a-1))
if(!ni2)
ni2+=c;
//xn=t !!! 不是自己求出来再看是不是和t在mod c意义下同余
long long m=(long long)ceil(sqrt(c)),ans,pd=0;
if(gcd(a,c)!=1)
{
printf("-1\n");
continue;
}
for(int j=0;j<=m;++j)
{
if(j==0)
{
ans=(t+b*ni1%c)*ni2%c;
mp[ans]=j;
continue;
}
ans=(ans*a)%c;
mp[ans]=j;
}
ans=1;
x=ksm(a,m,c);
for(int i=1;i<=m;++i)
{
ans=(ans*x)%c;
if(mp[ans])
{
x=i*m-mp[ans];
pd=1;
printf("%lld\n",(x%c+c)%c+1);
break;
}
}
if(!pd)
printf("-1\n");
}
return 0;
}