计算有多少 n ( 1 ≥ n ≥ x ) n(1\ge n\ge x) n(1≥n≥x) 满足
n ⋅ a n ≡ b ( m o d p ) n\cdot a^n\equiv b(mod\ p) n⋅an≡b(mod p)
a , b , p , x ( 2 ≤ p ≤ 1 0 6 + 3 ) ( 1 ≤ a , b < p ) ( 1 ≤ x ≤ 1 0 12 ) a,b,p,x(2 \leq p \leq 10^6+3)(1 \leq a,b < p)(1 \leq x \leq 10^{12}) a,b,p,x(2≤p≤106+3)(1≤a,b<p)(1≤x≤1012), p p p 保证为素数。
看到 p p p 的数据范围很容易联想,要枚举 p p p 范围内的 n n n ,再根据 x x x 范围内的 n n n 和当前 n n n 的关系计算贡献。
我们注意到当 n n n 大于 p p p 的时候会出现 n ⋅ a n ≡ ( n % p ) ⋅ a n % ( p − 1 ) ( m o d p ) n\cdot a^n\equiv (n\%p)\cdot a^{n\%(p-1)}(mod\ p) n⋅an≡(n%p)⋅an%(p−1)(mod p)。
我们令 x = n % ( p − 1 ) , y = n % p x=n\%(p-1),y=n\%p x=n%(p−1),y=n%p,则有 y ⋅ a x ≡ b ( m o d p ) y\cdot a^x\equiv b(mod\ p) y⋅ax≡b(mod p),则有 a x ≡ b ⋅ i n v ( y ) a^x\equiv b\cdot inv(y) ax≡b⋅inv(y),其中 i n v ( y ) inv(y) inv(y) 是关于 p p p 的乘法逆元。
于是我们线性时间算出 1 1 1 到 m i n ( x , p − 1 ) min(x,p-1) min(x,p−1) 的逆元(这里应该用快速幂暴力算时间应该也允许)
记录下所有的 b ⋅ i n v ( y ) b\cdot inv(y) b⋅inv(y) 和它对应的 y y y。
然后我们枚举 x x x,找到最小的 n n n 使得 n % ( p − 1 ) = x , n % p = y n\%(p-1)=x,n\%p=y n%(p−1)=x,n%p=y,也就是中国剩余定理。
然后对于一个小于等于 x x x 的数,它所能产生的贡献即为 ( x − n ) / ( p ⋅ ( p − 1 ) ) + 1 (x-n)/(p\cdot (p-1))+1 (x−n)/(p⋅(p−1))+1。
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
typedef long long LL;
const int maxn = 1e6 + 5;
LL num[maxn];
unordered_map<int, int> id;
LL a[3], b[3];
int n = 2;
LL qmod(LL a, LL b, LL mod){
LL ans = 0;
while(b > 0){
if(b & 1) ans = (ans + a) % mod;
a = (a + a) % mod;
b >>= 1;
}
return ans;
}
LL exgcd(LL a, LL b, LL &x, LL &y){
if(b == 0){
x = 1, y = 0;
return a;
}else{
LL gcd = exgcd(b, a % b, x, y);
LL t = y;
y = x - (a / b) * y;
x = t;
return gcd;
}
}
LL china(){
LL m = a[1], ans = b[1], x, y;
for(int i = 2; i <= n; i++){
LL c = (b[i] - ans % a[i] + a[i]) % a[i];
LL gcd = exgcd(m, a[i], x, y);
if(c % gcd) return -1; // 无解
LL ag = a[i] / gcd;
x = qmod(x, c / gcd, ag);
ans += x * m;
m *= ag;
ans = (ans % m + m) % m;
}
return (ans % m + m) % m;
}
int main(){
IOS;
int aa, bb, p;
LL x, pp;
cin >> aa >> bb >> p >> x;
a[1] = p;
a[2] = p - 1;
pp = (LL)p * (p - 1);
int limit = min(x, p - 1LL);
num[1] = 1;
for(int i = 2; i <= limit; i++)
num[i] = ((p - p / i) * num[p % i]) % p;
for(int i = 1; i <= limit; i++)
id[num[i] * bb % p] = i;
LL ans = 0, now = aa, minone;
for(int i = 1; i <= limit; i++){
if(id.count(now)){
b[1] = id[now];
b[2] = i;
minone = china();
if(x >= minone) ans += (x - minone) / pp + 1;
}
now = now * aa % p;
}
cout << ans << endl;
}