// ZOJ 3254
// Problem link: ZOJ 3254
// Solution: extended baby-step giant-step (BSGS) + cycle-length (multiplicative order) detection
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
// 64-bit integer shorthands used throughout the file.
typedef long long ll;
typedef unsigned long long ull;
// Capacity of the baby-step table R: BSGS stores ceil(sqrt(p)) entries,
// so 2^16 entries (plus slack) cover moduli up to 2^32.
const int N = 65536 + 10;
// Current test case as read by main(): base A, modulus P, target D.
ll A;
ll P;
ll D;
// Upper bound on the exponent x. Unsigned on purpose: the count returned
// by solve() can be M + 1, which no longer fits in ll when M == LLONG_MAX.
ull M;
/**
 * Baby-step table entry. BSGS fills R[j] with v = b^j mod p and id = j.
 * Entries order by value first and, for equal values, by smaller id, so that
 * after sorting the first entry of each value carries the smallest exponent.
 */
struct Bj{
    ll v;   // residue b^id mod p
    int id; // exponent that produced v
    bool operator < (const Bj &rhs) const{
        if(v != rhs.v) return v < rhs.v;
        return id < rhs.id;
    }
}R[N];
/**
 * Modular exponentiation by repeated squaring.
 * @param a base (any value; reduced into [0, p) before use)
 * @param b exponent; b <= 0 yields 1 % p
 * @param p modulus, assumed >= 1 with p * p fitting in ll (p < ~3.03e9)
 * @return a^b mod p, in [0, p)
 */
ll FastPowMod(ll a, ll b, ll p){
    ll ret = 1 % p;          // 1 % p so that p == 1 yields 0
    // Reduce the base first: otherwise the first a * a overflows for a large
    // base, and a negative base would leave negative residues in ret.
    a %= p;
    if(a < 0) a += p;
    // b > 0 (not b != 0): a negative exponent would never reach 0 under >>=.
    while(b > 0){
        if(b & 1) ret = ret * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ret;
}
/**
 * Binary search over R[0, r), which must be sorted ascending by v
 * (BSGS sorts and de-duplicates it before calling).
 * @return the id of an entry whose value equals v, or -1 if none does.
 */
int findL(ll v, int r){
    int lo = 0;
    int hi = r - 1;
    while(lo <= hi){
        const int mid = lo + (hi - lo) / 2;
        const ll cur = R[mid].v;
        if(cur < v){
            lo = mid + 1;
        }else if(cur > v){
            hi = mid - 1;
        }else{
            return R[mid].id;
        }
    }
    return -1;
}
/**
 * Extended Euclid: writes Bezout coefficients into x, y such that
 * a * x + b * y == gcd(a, b).
 */
void extend_gcd(ll a, ll b, ll &x, ll &y){
    if(b != 0){
        ll x1, y1;
        extend_gcd(b, a % b, x1, y1);
        // b*x1 + (a % b)*y1 == g  =>  a*y1 + b*(x1 - (a/b)*y1) == g
        x = y1;
        y = x1 - a / b * y1;
        return;
    }
    x = 1;
    y = 0;
}
/**
 * Returns b * a^{-1} mod p, normalised into [0, p) — with b == 1 this is
 * the modular inverse of a. Only meaningful when gcd(a, p) == 1; this is
 * not checked here.
 */
ll inv (ll a, ll b, ll p){
    ll s, t;
    extend_gcd(a, p, s, t);   // a*s + p*t == gcd(a, p)
    const ll r = (b * s) % p; // may be negative
    return (r + p) % p;
}
/**
 * Baby-step giant-step: returns the smallest x >= 0 with b^x ≡ n (mod p).
 * Requirements:
 * - gcd(b, p) == 1, since b^m must be invertible (solve() strips common
 *   factors first).
 * - ceil(sqrt(p)) <= N, i.e. p <= 2^32, so that the table R fits.
 * When no x exists, it returns the global M + 1 converted to ll. Callers must
 * treat that (possibly wrapped) value as "not found".
 * phi is unused; the parameter is kept so existing callers still compile.
 */
ll BSGS(ll b, ll p, ll n, ll phi){
    (void)phi;
    // Canonical residues: keeps every product below p * p and lets an
    // n >= p still match an exponent found on the first giant step.
    b %= p; if(b < 0) b += p;
    n %= p; if(n < 0) n += p;
    int m = ceil(sqrt((double)p));
    // Baby steps: R[j] = (b^j mod p, j) for j = 0 .. m-1.
    for(int i = 0; i < m; ++i) {
        R[i].id = i;
        R[i].v = i == 0 ? 1 % p : b * R[i - 1].v % p;
    }
    sort(R, R + m);
    // Keep one entry per value; sorting by (v, id) keeps the smallest j.
    int cnt = 1;
    for(int i = 1; i < m; ++i) if(R[i].v != R[i - 1].v) R[cnt++] = R[i];
    // Giant steps: look up n * b^(-m*i) among the baby steps.
    ll cur = n;
    ll bm = inv(FastPowMod(b, m, p), 1, p);   // b^(-m) mod p
    for(int i = 0; i < m; ++i){
        int pos = findL(cur, cnt);            // binary search
        // Widen before multiplying: i * m exceeds INT_MAX once m > 46340.
        if(~pos) return (ll)i * m + pos;
        cur = cur * bm % p;
    }
    return M + 1;
}
/** Greatest common divisor via Euclid's algorithm, iterative form. */
ll gcd(ll a, ll b) {
    while(b != 0){
        const ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}
/**
 * Euler's totient via trial-division factorisation:
 * phi(x) = x * prod over distinct primes q | x of (1 - 1/q).
 * Returns 1 for x == 1 and 0 for x == 0.
 */
ll getPhi(ll x){
    ll res = x;
    ll rest = x;   // part of x not yet factored
    for(ll q = 2; q * q <= rest; ++q){
        if(rest % q != 0) continue;
        res = res / q * (q - 1);
        do rest /= q; while(rest % q == 0);
    }
    // Whatever remains above 1 is a single prime factor larger than sqrt(x).
    if(rest > 1) res = res / rest * (rest - 1);
    return res;
}
/**
 * Cycle length of b^x mod p, i.e. the multiplicative order of b mod p.
 * solve() calls this with gcd(b, p) == 1 and phi == getPhi(p), so
 * b^phi ≡ 1 (mod p) and the order divides phi. It starts from phi and strips
 * each prime factor q of phi while b^(ord/q) is still 1.
 */
ll cal(ll b, ll p, ll phi){
    ll ord = phi;
    ll rest = phi;   // unfactored part of phi
    for(ll q = 2; q * q <= rest; ++q){
        if(rest % q != 0) continue;
        do rest /= q; while(rest % q == 0);
        while(ord % q == 0 && FastPowMod(b, ord / q, p) == 1) ord /= q;
    }
    // A leftover rest > 1 is the last (large) prime factor of phi.
    if(rest > 1){
        while(ord % rest == 0 && FastPowMod(b, ord / rest, p) == 1) ord /= rest;
    }
    return ord;
}
/**
 * Counts exponents x in [0, M] (M is the global bound) with b^x ≡ n (mod p),
 * where b need not be coprime with p (extended BSGS).
 *
 * 1. While d = gcd(b, p) > 1, peel one factor of b off both sides:
 *    b^x ≡ n (mod p)  ->  (b/d) * b^(x-1) ≡ n/d (mod p/d).
 *    Before each step, check whether x = cnt itself is a solution. Such a
 *    value is returned as the only solution, counted once if cnt <= M.
 * 2. Once gcd(b, p) == 1, BSGS gives the smallest y with
 *    b^y ≡ n * tmp^{-1} (mod p). Every further solution is
 *    x = cnt + y + k * ord, where ord = cal(...) is the order of b mod p.
 *
 * @return number of solutions in [0, M]; 0 when there are none.
 */
ull solve(ll b, ll p, ll n){
    const ll tn = n;      // original right-hand side
    const ll tp = p;      // original modulus
    ll tmp = 1 % p;       // product of peeled b/d factors, mod the reduced p
    ll tx = 1 % p;        // b^cnt mod the original p
    ll cnt = 0;           // number of peeled factors so far
    ll d;
    while((d = gcd(b, p)) != 1){
        // x = cnt solves it before elimination finishes; only count it
        // when it lies inside the requested range.
        if(tx == tn) return (ull)cnt <= M ? 1 : 0;
        if(n % d) return 0;                 // d does not divide n: unsolvable
        n /= d;
        p /= d;
        tmp = b / d * tmp % p;
        tx = tx * b % tp;
        ++cnt;
    }
    const ll phi = getPhi(p);
    const ll target = inv(tmp, n, p);       // n * tmp^{-1} mod p
    const ll y = BSGS(b, p, target, phi);
    // BSGS signals "not found" with M + 1 converted to ll. That value can be
    // negative, or wrap to 0 when M == ULLONG_MAX, so confirm b^y hits target.
    if(y < 0 || FastPowMod(b, y, p) != target) return 0;
    const ull first = (ull)y + (ull)cnt;    // smallest solution overall
    if(first > M) return 0;
    return 1 + (M - first) / (ull)cal(b, p, phi);
}
/**
 * Reads test cases "A P D M" until EOF and prints, per case, the number of
 * x in [0, M] with A^x ≡ D (mod P). Assumes P >= 1.
 */
int main(){
    // Require all four fields. Stopping only at EOF would loop forever on a
    // malformed token, because scanf would keep returning 0.
    while(scanf("%lld%lld%lld%llu", &A, &P, &D, &M) == 4){
        ull ans = solve(A % P, P, D);
        printf("%llu\n", ans);
    }
    return 0;
}