Law of Commutation HDU - 6189
As we all know, operation ”+” complies with the commutative law. That is, if we arbitrarily select two integers a and b, a+b always equals to b+a. However, as for exponentiation, such law may be wrong. In this problem, let us consider a modular exponentiation. Give an integer m=2n and an integer a, count the number of integers b in the range of [1,m] which satisfy the equation ab≡ba (mod m
).
Input
There are no more than 2500 test cases.
Each test case contains two positive integers n and a seperated by one space in a line.
For all test cases, you can assume that n≤30,1≤a≤109
.
Output
For each test case, output an integer denoting the number of b
.
Sample Input
2 3
2 2
Sample Output
1
2
题意:给定n,a,其中令 2n=m 2 n = m ,现求区间[1,m]内有多少个b满足, ab a b ≡ ≡ ba b a (mod m)。
解释:
1.当a为奇数的时候:
打表可以发现答案横为1,并且这个答案是a=b到时候
具体证明如下
对于a为奇数的情况,b一定为奇数,下证b=a mod 2^n。
由于奇数平方模8余1,故a^b=a mod 8, b^a=b mod 8
故a=b mod 8
由于奇数四次方模16余1,故a^b=a^(b%4) mod 16, b^a=b^(a%4) mod 16
由于b%4=a%4,故a=b mod 16
以此类推,得b=a mod 2^n。解唯一
2.当a为偶数的时候:
首先我们发现
m=2n
m
=
2
n
所以m是偶数,因为a是偶数所以
ab
a
b
是偶数,所以
ab
a
b
%m是偶数
⟶
⟶
ba
b
a
% m是偶数,所以b是偶数。
令a= 2x,
ab
a
b
=
(2x)b
(
2
x
)
b
=
2b⋅xb
2
b
⋅
x
b
此时我们又可以分出两种情况
一、b < n 的时候,因为题目给出到n很小,所以我们根据题意直接快速幂取模暴力算个数
二、b
≥
≥
n 的时候,发现
2b⋅xb
2
b
⋅
x
b
mod
2n
2
n
恰好为0,所以我们只需要使
ba
b
a
mod m为0即可,所以我们到目标转换成求满足
ba
b
a
mod m为0的b的个数
设b=
2x⋅y
2
x
⋅
y
,因为b是从[1,m]中找,所以我们只要找到b到最小值,用m除这个最小值就得到了,所有满足的b到个数,然后再减去[1,n]中到个数,因为已经求过了,所以就得到了b
≥
≥
n情况下b到个数。
因为我们想要求最小到b,且b是偶数,所以我们就可以忽略倍数y到影响,直接考虑
2x
2
x
,要使得
ba
b
a
mod m =
(2x)a
(
2
x
)
a
mod m =
2ax
2
a
x
mod
2n
2
n
≡
≡
0 mod
2n
2
n
,只需要ax
≥
≥
n,所以x
≥
≥
na
n
a
,因此我们就找到了满足条件到最小x值使得b最小,因为要大于等于,所以求
na
n
a
需要向上取整,即只要小数部分不为零,整数部分+1,小数部分变成0,求出b=
2x
2
x
,然后个数=
mb−nb
m
b
−
n
b
.
最后把总个数加起来
code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll q_pow(ll a,ll b,ll mod){
ll ans = 1;
while(b){
if(b & 1)
ans = ans * a % mod;
b >>= 1;
a = a * a % mod;
}
return ans;
}
ll n,a;
int main(){
while(scanf("%lld%lld",&n,&a) != EOF){
if(a & 1){
printf("1\n");
continue;
}
else{
ll m = 1 << n;
ll ans = 0;
for(ll i = 1; i <= n; i++){
if(q_pow(a,i,m) == q_pow(i,a,m))
ans++;
}
ll b2 = n / a;
if(b2 * a < n) b2++;
ll b3 = 1 << b2;
ll res = m / b3 - n / b3;
ans = ans + res;
printf("%lld\n",ans);
}
}
return 0;
}