题目
题目描述
求这个关于
n
n
n 的同余方程的不超过
x
x
x 的正整数解:
n
⋅
a
n
≡
b
(
m
o
d
p
)
n\cdot a^n\equiv b\pmod{p}
n⋅an≡b(modp)
数据范围与提示
2
≤
p
≤
1
0
6
+
3
2≤p≤10^6+3
2≤p≤106+3 且
1
≤
a
,
b
<
p
1≤a,b<p
1≤a,b<p 且
1
≤
x
≤
1
0
12
1≤x≤10^{12}
1≤x≤1012 。
思路
利用类似大步小步的方法。
设 n = q m − r ( 0 ≤ r < m ) n=qm-r(0\le r<m) n=qm−r(0≤r<m)
代入原式得到 ( q m − r ) a q m − r ≡ b ( m o d p ) (qm-r)a^{qm-r}\equiv b\pmod p (qm−r)aqm−r≡b(modp)
即 ( q m − r ) a q m ≡ b ⋅ a r ( m o d p ) (qm-r)a^{qm}\equiv b\cdot a^r\pmod p (qm−r)aqm≡b⋅ar(modp)
似乎项数太多,做不了?不妨规定 m ∣ p m|p m∣p ,重新审视式子,得到 − r ⋅ a q m ≡ b ⋅ a r ( m o d p ) -r\cdot a^{qm}\equiv b\cdot a^r\pmod p −r⋅aqm≡b⋅ar(modp)
于是可以拿到经典的大步小步式子 a q m ≡ − r − 1 a r ⋅ b ( m o d p ) a^{qm}\equiv -r^{-1}a^r\cdot b\pmod p aqm≡−r−1ar⋅b(modp)
为了满足 O ( m ) = O ( x ) \mathcal O(m)=\mathcal O(\sqrt x) O(m)=O(x) 且 m ∣ p m|p m∣p ,可以令 m = p ⌊ x p ⌋ m=p\lfloor\frac{\sqrt x}{p}\rfloor m=p⌊px⌋ ,时间复杂度就是 O ( x ) \mathcal O(\sqrt{x}) O(x) 的。
代码
极致的压行,给你极致的阅读体验。
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;
inline long long readint(){
long long a = 0; char c = getchar(), f = 1;
for(; c<'0' or c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c and c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
inline void writeint(long long x){
if(x > 9) writeint(x/10);
putchar((x%10)^48);
}
# define MB template < class T >
MB void getMax(T &a,const T &b){ if(a < b) a = b; }
MB void getMin(T &a,const T &b){ if(b < a) a = b; }
const int MaxP = 1000005;
int cnt[MaxP], inv[MaxP]={1,1};
int main(){
int a = readint(), b = readint(), p = readint();
int rp = p; /* real p */ p = p*(MaxP/p);
long long x = readint(); // 会爆int
if(x%p == 0) -- x; // 下面就可以省去取模
for(int i=(inv[1]=1)+1; i<rp; ++i)
inv[i] = (0ll+rp-rp/i)*inv[rp%i]%rp;
int a_p = 1; /* a的p次方 */ long long ans = 0;
for(int i=1; i<=p; ++i) a_p = 1ll*a_p*a%rp;
/* 第一次计算 */
for(int r=1,now=rp-b; r<p; ++r) now = 1ll*now*a%rp,
((r%rp != 0) && ++ cnt[1ll*now*inv[r%rp]%rp]);
for(int m=1,now=1; 1ll*m*p-(p-1)<=x; ++m)
now = 1ll*now*a_p%rp, ans += cnt[now];
/* 去掉不合法 */
for(int i=0; i<MaxP; ++i) cnt[i] = 0; // 清空
for(int r=1,now=rp-b; r<p-(x%p); ++r) now = 1ll*now*a%rp,
((r%rp != 0) && ++ cnt[1ll*now*inv[r%rp]%rp]);
for(int m=1,now=1; true; ++m)
if(1ll*m*p-(p-1) > x){
ans -= cnt[now]; break;
} else now = 1ll*now*a_p%rp;
printf("%lld\n",ans);
return 0;
}