Problem Description
baidu熊最近在学习随机算法,于是他决定自己做一个随机数生成器。
这个随机数生成器通过三个参数c,q,n作为种子, 然后它就可以通过以下方式生成伪随机数序列:
m0 = c,
mi+1 = (q2mi + 1) mod 2n, for all i > 0.
因为一些奇怪的原因,q一定是奇数。现在du熊想知道对于一个给定的数x,是不是会出现在这个伪随机数序列里面,如果存在的话,他还想知道最早是在哪里出现,即给定一个整数x,要求找出一个最小的整数k满足mk = x.
Input
输入包含多组数据。
每个测试数据包含一行三个整数: c, q, n, x.
数据满足 0 <= c < 2n, 0 <= q2 < 263, 0 < n <= 63, 0 <= x < 263.
输入以文件结束符结尾。
Output
对应每个测试数据输出满足条件的k,如果x不会出现在序列里面的话,就输出-1。
Sample Input
1 3 3 5
1 3 3 5
1 3 2 5
Sample Output
4
4
-1
二、解题思路:
设X(0)=0, X(i + 1) = (X(i) * q^2 + 1),有X(n) = 1 + q^2 + q^4 + ... + q^(2(n - 1)) (n>=1),即等比数列求和。
因为q是奇数,因此和2^n互质,于是X(n) mod (2^n)的周期就是2^n,即X(0),X(1),...,X(2^n - 1)是0~(2^n-1)的一个排列,因此只要0<=x<2^n就必定有解,否则没有解。且若X(r) mod (2^n)= x, (0<=r<2^n),那么对所有非负整数k有X(k * 2^n + r) mod (2^n)= x,其中r就是最小的一个,对于其它的位置i, X(i) mod (2^n)都不等于x。
用L(x, y)表示x的二进制表示的最低y位,即L(x,y)=x mod (2^y)。
用R(x,y)表示满足X(R(x, y)) mod (2^y) = L(x,y)的最小位置。
那么由上面的过程可知只有R(x,y) + k * 2^y的形式满足X(R(x,y) + k * 2^y) mod (2^y) = L(x,y),显然0<=R(x,y)<2^y
于是我们考虑y+1的情况,由L(x,y)定义可以知道,L(x,y + 1) mod (2 ^ y) = L(x,y)
因此,X(R(x,y+1)) mod (2^y) = X(R(x, y)) mod (2^y),因此R(x, y+1)满足R(x, y+1) = R(x, y) + k * 2^y,而0<=R(x, y+1)<2^(y+1),因此只有两种情况:
R(x,y+1)=R(x,y)或者R(x,y+1)=R(x,y)+2^y
由于X(n)可以通过快速求等比数列的方式快速得到(例如用矩阵乘法),因此我们可以这两种情况都求一下再比较一下就可以得到R(x, y+1)了。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef unsigned long long ULL;
typedef long long LL;
ULL c, n, q, x;
ULL _2n;
int debug = 0;
ULL geti(ULL i, ULL &d){
if(i == 1){
d = q;
return 1;
}
if(i & 1){
ULL dd;
ULL x = geti(i >> 1, dd);
d = dd * dd * q;
return x + x * dd + dd * dd;
} else {
ULL dd;
ULL x = geti(i >> 1, dd);
d = dd * dd;
return x + x * dd;
}
}
ULL findi(ULL i){
if(i == 0) return 0;
ULL d;
return geti(i, d);
}
ULL find(ULL x){
ULL m = 1, k = 0;
for(int i = 1; i <= n; i++){
ULL a = findi(k + m);
if(debug) cout << a % _2n << ", K: " << k << " , M: " << m << endl;
m = m << 1;
if((x % m) == (a % m)) k += m >> 1;
}
return k;
}
void Knuth(){
/* I believe in Knuth */
while(cin >> c >> q >> n >> x){
_2n = 1;
for(int i = 0; i < n; i++) _2n += _2n;
if(debug) cout << "_2n: " << _2n << endl;
q *= q;
q %= _2n;
if(x >= _2n){
cout << "-1" << endl;
continue;
}
ULL kc = find(c);
if(debug == 1) cout << "kc : " << kc << endl;
ULL kx = find(x);
if(debug == 1) cout << "kx : " << kx << endl;
if(debug == 2) {
for(int i = 0; i < _2n; i++) cout << findi(i) % _2n << " ";
cout << endl;
}
if(kc <= kx) cout << kx - kc << endl;
else cout << kx + _2n - kc << endl;
}
}
int main(){
Knuth();
return 0;
}