有时候我们需要按照元素的权重随机选择下一个元素。例如在random walk算法中,我们需要根据边的权重选择下一个节点,边的权重越大,下一个节点被选中的概率也就越大。
假设我们以邻接表的方式保存图,那么选择下一个节点的算法如下所示,这是最朴素的方法。这种方法每次选择下一个节点的时候都要遍历当前节点的所有邻接节点。
/**
 * Returns a pseudo-random double uniformly distributed in [0, 1).
 *
 * The full 31-bit output of random() is used. Reducing it modulo 100 would
 * leave only 100 distinct results, so a caller that scales the value over a
 * long list of weights could never reach most of the list.
 */
double myRandom(){
    // POSIX specifies random() returns a value in [0, 2^31 - 1]; dividing by
    // 2^31 keeps the result strictly below 1.
    return random() / 2147483648.0;
}

/**
 * Picks one neighbor at random with probability proportional to its weight,
 * using a linear scan over the cumulative weights.
 *
 * @param neighbors (node id, weight) pairs; weights are expected to be
 *                  non-negative.
 * @return the id of the chosen neighbor; -1 if `neighbors` is empty; the id
 *         of the first neighbor if the weights do not sum to a positive value.
 */
int weighted_next(const std::vector<std::pair<int, double>> &neighbors){
    if (neighbors.empty())
        return -1;

    double total = 0.0;
    for (const std::pair<int, double> &neighbor : neighbors)
        total += neighbor.second;

    // threshold lies in [0, total); the chosen neighbor is the first one whose
    // cumulative weight exceeds it. The strict comparison means a neighbor
    // with zero weight can never be selected.
    const double threshold = myRandom() * total;
    double cumulative = 0.0;
    for (const std::pair<int, double> &neighbor : neighbors) {
        cumulative += neighbor.second;
        if (threshold < cumulative)
            return neighbor.first;
    }

    // Only reached when the total weight is not positive.
    return neighbors.front().first;
}
本人在node2vec_spark中看到了另外一种实现(即别名采样,alias method),这种方法首先为每个节点保存两个数组:J和q,它们的大小与该节点的邻居节点数量相同。这种方法的优点是每次选择下一个节点的时候不用遍历所有的邻接节点,理论上如果需要多次根据当前节点选择下一个节点,可以节省时间。其算法如下所示:
/**
 * Builds the alias table used for constant-time weighted sampling.
 *
 * For n neighbors the result holds two arrays of length n:
 *   - q[i]: probability of keeping slot i once it has been drawn uniformly;
 *   - J[i]: index of the neighbor to use instead when slot i is not kept.
 *
 * Every slot starts out as its own alias, so a slot that is never paired
 * (including one left over because of floating-point rounding) still
 * resolves to itself rather than to neighbor 0.
 *
 * @param neighbors (node id, weight) pairs; weights are expected to be
 *                  non-negative.
 * @return the pair (J, q). Both are empty for empty input. If the weights do
 *         not sum to a positive value the table describes a uniform choice.
 */
std::pair<std::vector<int>, std::vector<double>> init(const std::vector<std::pair<int, double>> &neighbors){
    const int count = static_cast<int>(neighbors.size());
    std::vector<int> J(count);    // alias index per slot
    std::vector<double> q(count); // keep-probability per slot

    double total = 0.0;
    for (const std::pair<int, double> &neighbor : neighbors)
        total += neighbor.second;

    std::vector<int> smaller; // slots whose scaled weight is below 1
    std::vector<int> larger;  // slots whose scaled weight is at least 1
    for (int i = 0; i < count; i++) {
        J[i] = i;
        // Scale so the average slot weighs exactly 1; fall back to 1 (uniform)
        // when the total is unusable, which avoids filling q with NaN.
        q[i] = total > 0.0 ? neighbors[i].second * count / total : 1.0;
        if (q[i] < 1.0)
            smaller.push_back(i);
        else
            larger.push_back(i);
    }

    // Repeatedly top up an under-full slot with mass from an over-full one.
    while (!smaller.empty() && !larger.empty()) {
        const int small = smaller.back();
        const int large = larger.back();
        smaller.pop_back();
        larger.pop_back();

        J[small] = large;
        q[large] = q[large] + q[small] - 1.0;
        if (q[large] < 1.0)
            smaller.push_back(large);
        else
            larger.push_back(large);
    }

    return std::make_pair(std::move(J), std::move(q));
}
int weighted_next(vector<pair<int, double>> neighbors, vector<int> J, vector<double> q){
int length = static_cast<int>(neighbors.size());
int index = static_cast<int>(random() % length);
if(myRandom() >= q[index])
index = J[index];
return neighbors[index].first;
}
实验表明,第二种方法比第一种方法稍微节省一点时间。
int main() {
    // Micro-benchmark: times the linear-scan sampler, the alias-table
    // construction, and the alias-table sampler on one large neighbor list.
    const int kNeighborCount = 100000;
    const int kDraws = 1000;

    // Prints the processor time between two clock() readings, in seconds.
    auto report = [](clock_t begin, clock_t finish) {
        cout << "The run time is: " << static_cast<double>(finish - begin) / CLOCKS_PER_SEC << "s" << endl;
    };

    // Neighbor ids 0..kNeighborCount-1 with random weights.
    vector<pair<int, double>> neis;
    neis.reserve(kNeighborCount);
    for (int id = 0; id < kNeighborCount; ++id)
        neis.emplace_back(id, myRandom() * 1000);

    // Phase 1: repeated draws with the linear-scan sampler.
    clock_t begin = clock();
    for (int draw = 0; draw < kDraws; ++draw)
        weighted_next(neis);
    clock_t finish = clock();
    report(begin, finish);

    // Phase 2: one-off construction of the alias table.
    begin = clock();
    pair<vector<int>, vector<double>> table = init(neis);
    finish = clock();
    report(begin, finish);

    // Phase 3: repeated draws with the alias-table sampler.
    begin = clock();
    for (int draw = 0; draw < kDraws; ++draw)
        weighted_next(neis, table.first, table.second);
    finish = clock();
    report(begin, finish);

    return 0;
}
//The run time is: 1.50029s
//The run time is: 0.01662s
//The run time is: 1.36263s