// return (a * b) % m
LL mod_mult(LL a, LL b, LL m)
{
LL res = 0;
LL exp = a % m;
while (b)
{
if (b & 1)
{
res += exp;
if (res > m) res -= m;
}
exp <<= 1;
if (exp > m) exp -= m;
b >>= 1;
}
return res;
}
求余防止溢出
这个求余乘法的思想是,先将一个数用2进制表示:
bn表示b的二进制的第n个bit,当然,首个比特是从0开始算的。将a乘入括号中,得到:
由于bn要么是0要么是1,所以只需计算为1的部分就可以了,比如3*5:
每加一次就求一次余,这样每次加上去的都是小于m的余数,这样就不怕溢出了。由于每个bit都需要计算一次,所以复杂度是O(log(N))。