题目链接:https://ac.nowcoder.com/acm/contest/205/B?&headNav=www
原来窝被这算法虐过么= =!
这个算法其实是用来解决RSA中计算公钥的问题的,是个底层常数优化算法。。
主要解决两个问题:
-
蒙哥马利约减,即 t p − 1 ( m o d m ) tp^{-1}\ (mod\ m) tp−1 (mod m)
-
蒙哥马利乘模,即 x y ( m o d m ) x y\ (mod\ m) xy (mod m)
这个算法的目的就是避免取模和除法来降低常数
算法过程,直接看杜教板子比看网上的博客来得简单。。
首先预处理几个参数:
b为数制,这里所有的数都以b进制的蒙哥马利表示法表示
p
=
b
k
p=b^k
p=bk且
p
≥
m
p\ge m
p≥m
i
n
v
=
m
−
1
m
o
d
p
inv=m^{-1}\ mod\ p
inv=m−1 mod p
r
2
=
p
2
m
o
d
m
r2=p^2\ mod\ m
r2=p2 mod m
取b进制是为了方便进行位运算,所以b一般取2,然后p取
2
64
2^{64}
264的话取模的时候自然溢出就行了,比较方便
然后关于蒙哥马利表示法,就是x的蒙哥马利表示为
x
p
(
m
o
d
m
)
xp\ (mod\ m)
xp (mod m)
所有数的运算都在此表示法下进行
蒙哥马利约减:
若
t
≤
m
2
t\le m^2
t≤m2,可以用
(
t
−
(
t
∗
i
n
v
m
o
d
p
)
∗
m
)
/
p
(t-(t*inv\ mod\ p)*m )/p
(t−(t∗inv mod p)∗m)/p 代替
t
p
−
1
m
o
d
m
tp^{-1}\ mod \ m
tp−1 mod m
证明:
因为
t
−
t
∗
i
n
v
∗
m
≡
t
+
t
(
−
m
−
1
∗
m
)
≡
0
(
m
o
d
p
)
t-t*inv*m\equiv t+t(-m^{-1}*m)\equiv0\ (mod\ p)
t−t∗inv∗m≡t+t(−m−1∗m)≡0 (mod p)
所以
p
∣
(
t
+
t
∗
i
n
v
∗
m
)
p|(t+t*inv*m)
p∣(t+t∗inv∗m),即
p
∣
(
t
+
(
t
∗
i
n
v
m
o
d
p
)
∗
m
)
p|(t+(t*inv\ mod\ p)*m)
p∣(t+(t∗inv mod p)∗m)
故有
(
t
+
(
t
∗
i
n
v
m
o
d
p
)
∗
m
)
/
p
≡
t
p
−
1
+
(
t
∗
i
n
v
m
o
d
p
)
p
−
1
m
≡
t
p
−
1
(
m
o
d
m
)
(t+(t*inv\ mod\ p)*m )/p\equiv tp^{-1}+(t*inv\ mod \ p)p^{-1}m\equiv tp^{-1}\ (mod \ m)
(t+(t∗inv mod p)∗m)/p≡tp−1+(t∗inv mod p)p−1m≡tp−1 (mod m)
在数值上,由于
(
t
−
(
t
∗
i
n
v
m
o
d
p
∗
m
)
/
p
)
≤
m
2
/
p
≤
m
(t-(t*inv\ mod \ p*m)/p)\le m^2/p\le m
(t−(t∗inv mod p∗m)/p)≤m2/p≤m
(
t
−
(
t
∗
i
n
v
m
o
d
p
∗
m
)
/
p
)
≥
−
m
(t-(t*inv\ mod \ p*m)/p)\ge -m
(t−(t∗inv mod p∗m)/p)≥−m
所以只需判断正负即可
然后就可以用蒙哥马利约减来求出一个数的蒙哥马利表示了,直接对
x
∗
r
2
x*r2
x∗r2进行约减即可
从蒙哥马利表示法中还原出原数也只需做一个约减就可以了
蒙哥马利乘模:
令
x
^
=
x
p
(
m
o
d
m
)
\hat{x}=xp\ (mod\ m)
x^=xp (mod m)
y
^
=
y
p
(
m
o
d
m
)
\hat{y}=yp\ (mod\ m)
y^=yp (mod m)
那么在蒙哥马利表示法中,
x
y
^
=
x
y
p
\hat{xy}=xyp
xy^=xyp
即
x
y
^
≡
x
^
y
^
/
p
(
m
o
d
m
)
\hat{xy} \ \equiv \ \hat{x}\hat{y}/p\ (mod \ m)
xy^ ≡ x^y^/p (mod m)
直接对
x
∗
y
x*y
x∗y进行一次约减就可以了
基本就这么多。。说那么多其实只要有板子就可以。。
这里就贴杜教的板子,窝的板子实在是太丑了。。。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long u64;
typedef __int128_t i128;
typedef __uint128_t u128;
struct Mod64 {
Mod64() :n_(0) {}
Mod64(u64 n) :n_(init(n)) {}
static u64 init(u64 w) { return reduce(u128(w) * r2); }
static void set_mod(u64 m) {
mod = m; assert(mod & 1);
inv = m; for (int i = 0; i < 5; ++i) inv *= 2 - inv * m;
r2 = -u128(m) % m;
}
static u64 reduce(u128 x) {
u64 y = u64(x >> 64) - u64((u128(u64(x)*inv)*mod) >> 64);
return ll(y)<0 ? y + mod : y;
}
Mod64& operator += (Mod64 rhs) { n_ += rhs.n_ - mod; if (ll(n_)<0) n_ += mod; return *this; }
Mod64 operator + (Mod64 rhs) const { return Mod64(*this) += rhs; }
Mod64& operator -= (Mod64 rhs) { n_ -= rhs.n_; if (ll(n_)<0) n_ += mod; return *this; }
Mod64 operator - (Mod64 rhs) const { return Mod64(*this) -= rhs; }
Mod64& operator *= (Mod64 rhs) { n_ = reduce(u128(n_)*rhs.n_); return *this; }
Mod64 operator * (Mod64 rhs) const { return Mod64(*this) *= rhs; }
u64 get() const { return reduce(n_); }
static u64 mod, inv, r2;
u64 n_;
};
u64 Mod64::mod, Mod64::inv, Mod64::r2;
int t, k;
u64 A0, A1, M0, M1, C, M;
void Run()
{
scanf("%d", &t);
while (t--)
{
scanf("%llu%llu%llu%llu%llu%llu%d", &A0, &A1, &M0, &M1, &C, &M, &k);
Mod64::set_mod(M);
Mod64 a0(A0), a1(A1), m0(M0), m1(M1), c(C), ans(1), a2(0);
ans *= a0; ans *= a1;
for (int i = 2; i <= k; ++i)
{
a2 = a1;
a1 = m0 * a1 + m1 * a0 + c;
a0 = a2;
ans *= a1;
}
printf("%llu\n", ans.get());
}
}
int main()
{
Run();
return 0;
}