RMQ的FischerHeun算法实现(代码可能有错)

RMQ 的 Fischer-Heun 算法是一种构思非常巧妙的 RMQ 算法变种。该算法的核心思想是:先利用 ballot number 模拟笛卡尔树的构建过程,枚举并计算每一棵笛卡尔树对应的类型;然后在预处理查询序列时,计算每个存有实际数据的 segment 所对应的笛卡尔树类型;最后在 RMQ 查询时,根据 segment 的类型找到对应的笛卡尔树,利用其最近公共祖先完成 RMQ 查询。这个算法的精华在于笛卡尔树类型的计算:原论文中的代码不算复杂,但原理比较难以理解。我在实现这个算法的过程中半懂不懂,多次尝试后总算完成了代码编写,但代码可能存在未知的错误和 bug:目前只对长度为 256 的查询序列(segment 长度为 3)通过了测试用例,更长的 segment 由于数据量过大没有测试。发现错误请告知,我会尽可能修改正确。
具体参见论文《Theoretical and Practical Improvements on the RMQ-Problem, with Applications to LCA and LCE》
作者 Johannes Fischer and Volker Heun
C++代码:

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <stack>
#include <vector>
using namespace std;

// Integer base-2 logarithm, rounded down.
//
// N       value whose logarithm is wanted.
// returns floor(log2(N)) for N >= 1, and 0 for N == 0. The logarithm is
//         undefined at 0; returning 0 keeps callers from receiving a
//         wrapped-around SIZE_MAX.
size_t log2(size_t N)
{
	size_t k = 0;
	// Halving N instead of growing `1ull << k` avoids a shift by the full bit
	// width, which is undefined and was reached whenever the top bit of N was set.
	while (N >>= 1)
	{
		++k;
	}
	return k;
}

// A value together with the position it was taken from.
//
// The frame of reference of `index` is chosen by whoever builds the record:
// the code in this file stores in-block positions, block numbers and whole-
// sequence positions in it at different stages.
template <typename T>
struct m_info
{
	// Default member initialisers make `m_info<T> x;` well defined; previously
	// both fields were left indeterminate by the defaulted constructor.
	T m_value{};
	size_t index = 0;

	m_info() = default;
	m_info(const T& m, size_t i) : m_value(m), index(i) {}

	// Two records are equal only when both the value and the index match.
	bool operator == (const m_info& be_compared) const
	{
		return m_value == be_compared.m_value && index == be_compared.index;
	}
};

// Selects the record with the smaller value. When the two values compare
// equal the left-hand record is chosen. The returned reference aliases one of
// the two arguments, so it is only valid while that argument is alive.
template <typename T>
const m_info<T>& min_m_info(const m_info<T>& left, const m_info<T>& right)
{
	return (left.m_value <= right.m_value) ? left : right;
}

template <typename T>
void SparseTable(vector<vector<m_info<T>>>& d, vector<T>& seq, size_t left, size_t right)  //[left, right]从0开始
{
	for (size_t i = 0; i <= right - left; ++i)
	{
		d[i][0] = m_info<T>(seq[i + left], i);
	}

	for (size_t i = 1; (1 << i) <= right - left + 1; ++i)
	{
		for (size_t j = 0; j <= right - left + 1 - (1 << i); ++j)
		{
			if (d[j][i - 1].m_value <= d[j + (1 << i - 1)][i - 1].m_value)
			{
				d[j][i] = d[j][i - 1];
			}
			else
			{
				d[j][i] = d[j + (1 << i - 1)][i - 1];
			}
		}
	}
}

// Range-minimum lookup in a sparse table produced by SparseTable.
//
// i, j    bounds of the queried range, 0-based and both inclusive; i <= j.
// d       the sparse table.
// returns the table entry for the minimum of the range. Its index is whatever
//         SparseTable stored (an offset inside the indexed range). On equal
//         values the entry of the left window is returned.
template <typename T>
m_info<T> query(size_t i, size_t j, vector<vector<m_info<T>>>& d)  // i and j are 0-based
{
	const size_t k = log2(j - i + 1);
	// The two windows of length 2^k, anchored at i and ending at j, cover [i, j].
	const size_t right_start = j + 1 - (size_t{1} << k);
	const m_info<T>& from_left = d[i][k];
	const m_info<T>& from_right = d[right_start][k];
	return (from_left.m_value <= from_right.m_value) ? from_left : from_right;
}

// Range-minimum lookup inside one block through a per-type position table.
//
// i, j       bounds inside the block, 0-based and both inclusive.
// d          position table of the block's type; d[i][j] holds the 1-based
//            in-block position of the minimum of [i, j].
// offset     position in value_seq of the block's first element.
// value_seq  the full sequence the block belongs to.
// returns    the minimum value together with its 0-based position inside the
//            block (the caller adds `offset` when it needs a global position).
template <typename T>
m_info<T> query(size_t i, size_t j, vector<vector<size_t>>& d, size_t offset, vector<T> &value_seq)    // i and j are 0-based
{
	const size_t pos_in_block = d[i][j] - 1;
	return m_info<T>(value_seq[offset + pos_in_block], pos_in_block);
}

// Fills the upper triangle of a square table with ballot numbers.
//
// ballot_num  square table, expected to arrive zero-filled. On return row 0 is
//             all ones and, for 1 <= row <= col,
//             ballot_num[row][col] = ballot_num[row][col - 1] + ballot_num[row - 1][col].
//             Cells below the diagonal are not written.
void computeBallotNumber(vector<vector<size_t>>& ballot_num)
{
	const size_t dim = ballot_num.size();

	for (size_t col = 0; col < dim; ++col)
	{
		ballot_num[0][col] = 1;
	}

	for (size_t col = 1; col < dim; ++col)
	{
		for (size_t row = 1; row <= col; ++row)
		{
			const size_t from_left = ballot_num[row][col - 1];
			const size_t from_above = ballot_num[row - 1][col];
			ballot_num[row][col] = from_left + from_above;
		}
	}
}

// Kind of stack operation recorded while enumerating block types.
enum class StackOperateType { PUSH, POP };

// One recorded stack operation: what was done and to which element.
struct StackOperationSeqNode
{
	// Default member initialisers give a default-constructed node a defined
	// state; the defaulted constructor used to leave both fields indeterminate.
	StackOperateType op_type = StackOperateType::PUSH;
	size_t be_operated_index = 0;  // element the operation applies to (the code in this file uses 1-based positions)

	StackOperationSeqNode() = default;
	StackOperationSeqNode(StackOperateType o, size_t b) : op_type(o), be_operated_index(b) {}
};

// Range-minimum-query structure over a fixed sequence. The sequence is cut
// into blocks of segment_length elements; each block is answered through a
// position table shared by all blocks of the same type, and ranges spanning
// whole blocks are answered through a sparse table of the block minima.
//
// NOTE: the members below are initialised in declaration order, and the
// constructor's initialisers for segment_length, block_num and the tables read
// the members declared before them. Do not reorder the declarations.
template <typename T>
class FischerHeunRMQ
{
public:
	// Minimum of the stored sequence over the closed, 0-based range [i, j],
	// returned with its position in the sequence.
	m_info<T> queryRMQ(size_t i, size_t j);
	// v is copied; m is the value used to pad the copy up to a whole number of
	// blocks.
	FischerHeunRMQ(vector<T>& v, const T& m);
private:
	// Copy of the constructor argument m (the padding value).
	T max_value_for_type;
	// Length of the sequence as passed in, before padding.
	size_t N;
	// Elements per block: log2(N) / 4 + 1.
	size_t segment_length;
	// Number of blocks: N / segment_length, rounded up.
	size_t block_num;
	// Copy of the input, padded so that its length is a multiple of segment_length.
	vector<T> value_seq;
	// block_type[b] is the type code computeTypeForEveryBlock returned for block b.
	vector<size_t> block_type;
	// [type][i][j] -> 1-based in-block position of the minimum of in-block range [i, j].
	vector<vector<vector<size_t>>> group_in_query_map_table;
	// Sparse table over the per-block minima; entries carry block numbers as index.
	vector<vector<m_info<T>>> group_between_query_table;
	// min_in_block_index[b] is the 0-based in-block position of block b's minimum.
	vector<size_t> min_in_block_index;
	// Computes block_type and min_in_block_index, then builds group_between_query_table.
	void preProcess(vector<vector<size_t>>& ballot_num);
	// Type code of the block that starts at value_seq[left_start].
	size_t computeTypeForEveryBlock(size_t left_start, vector<vector<size_t>>& ballot_num);
	// Recursively enumerates push/pop sequences and fills the position table of each type reached.
	void enumerateBlockTypePreprocessRMQ(size_t i, size_t q, stack<size_t>& work_stack, vector<StackOperationSeqNode>& op_seq, size_t op_seq_index, size_t N, vector<vector<size_t>>& ballot_num, vector<vector<vector<size_t>>>& group_in_query_map_table);
	// Replays op_seq[0 .. op_seq_index] and fills the off-diagonal entries of LCA.
	void computeRMQForType(const vector<StackOperationSeqNode>& op_seq, size_t op_seq_index, vector<vector<size_t>>& LCA);
};

// Classifies every block and indexes the block minima.
//
// For each block this stores its type code in block_type and the in-block
// position of its minimum in min_in_block_index; the minima themselves are
// then handed to SparseTable to build group_between_query_table.
//
// ballot_num  ballot-number table, forwarded to computeTypeForEveryBlock.
template <typename T>
void FischerHeunRMQ<T>::preProcess(vector<vector<size_t>>& ballot_num)
{
	vector<T> block_minima(block_num);

	for (size_t block = 0; block < block_num; ++block)
	{
		const size_t block_start = block * segment_length;
		block_type[block] = computeTypeForEveryBlock(block_start, ballot_num);

		// Whole-block query through the position table of this block's type.
		const m_info<T> whole_block = query(0, segment_length - 1, group_in_query_map_table[block_type[block]], block_start, value_seq);
		block_minima[block] = whole_block.m_value;
		min_in_block_index[block] = whole_block.index;
	}

	SparseTable(group_between_query_table, block_minima, 0, block_minima.size() - 1);
}

// Builds every lookup structure for the sequence v.
//
// v  sequence to answer queries on. It is copied into value_seq, so later
//    changes to v are not seen by this object.
// m  value appended to the copy until its length is a multiple of
//    segment_length; also stored in max_value_for_type. Presumably it must be
//    no smaller than any element of v (main passes the largest long long);
//    TODO confirm the intended contract.
//
// Prints an error and terminates the process with exit(-1) when v is empty.
//
// The initialisers are listed in member declaration order, which is the order
// C++ runs them in: segment_length reads N, block_num reads both, and the
// tables read block_num.
template <typename T>
FischerHeunRMQ<T>::FischerHeunRMQ(vector<T>& v, const T& m)
	: max_value_for_type(m),
	  N(v.size()),
	  segment_length(log2(N) / 4 + 1),
	  block_num(N % segment_length != 0 ? N / segment_length + 1 : N / segment_length),
	  value_seq(v),
	  block_type(block_num),
	  group_between_query_table(block_num, vector<m_info<T>>(log2(block_num) + 1)),
	  min_in_block_index(block_num)
{
	if (value_seq.empty())
	{
		cout << "ERROR输入序列为空!" << endl;
		exit(-1);
	}

	// Pad the last block with m so every block holds exactly segment_length
	// elements. This is a no-op when N is already a multiple of segment_length.
	value_seq.resize(block_num * segment_length, m);

	vector<vector<size_t>> ballot_num(segment_length + 1, vector<size_t>(segment_length + 1, 0));
	computeBallotNumber(ballot_num);

	// One segment_length x segment_length position table per type code. The
	// table count is kept at ballot_num[segment_length][segment_length] + 1 as
	// before. Plain assignment replaces the former placement-new, which
	// re-constructed a live member without destroying it first.
	const size_t table_count = ballot_num[segment_length][segment_length] + 1;
	group_in_query_map_table.assign(table_count, vector<vector<size_t>>(segment_length, vector<size_t>(segment_length)));

	// A single-element range is its own minimum (positions are stored 1-based).
	for (auto& table : group_in_query_map_table)
	{
		for (size_t j = 0; j < segment_length; ++j)
		{
			table[j][j] = j + 1;
		}
	}

	if (segment_length > 1)
	{
		stack<size_t> work_stack;
		work_stack.push(1);
		// Scratch buffer holding the push/pop sequence of the type being enumerated.
		vector<StackOperationSeqNode> op_seq(2 * segment_length - 1);
		enumerateBlockTypePreprocessRMQ(2, segment_length, work_stack, op_seq, 0, 0, ballot_num, group_in_query_map_table);
	}

	preProcess(ballot_num);
}

// Computes the type code of the block value_seq[left_start .. left_start + segment_length - 1].
//
// The block's elements are fed one by one into a stack kept in
// `rightmost_path`; every element removed from that stack because it is
// strictly greater than the incoming one adds one ballot number to the code.
//
// left_start  position in value_seq of the block's first element.
// ballot_num  table filled by computeBallotNumber.
// returns     the accumulated type code.
template <typename T>
size_t FischerHeunRMQ<T>::computeTypeForEveryBlock(size_t left_start, vector<vector<size_t>>& ballot_num)
{
	vector<T> rightmost_path(segment_length + 2);
	size_t remaining = segment_length;
	size_t type_code = 0;

	for (size_t pos = 1; pos <= segment_length; ++pos)
	{
		const T& incoming = value_seq[left_start + pos - 1];

		// Slot of the current stack top; slot 1 stands for "stack empty".
		size_t top = remaining + pos - segment_length;
		while (top != 1 && rightmost_path[top] > incoming)
		{
			type_code += ballot_num[segment_length - pos][remaining];
			--remaining;
			--top;
		}
		rightmost_path[top + 1] = incoming;
	}
	return type_code;
}

// Replays a recorded push/pop sequence and fills the position table of one type.
//
// op_seq        recorded operations; be_operated_index values are 1-based.
// op_seq_index  index of the last operation to replay (inclusive).
// LCA           table to fill; LCA[a][b] is written with a 1-based position
//               for every column b that gets pushed and every row a < b.
//               The diagonal is expected to be initialised by the caller.
//
// The replay starts from a stack that already holds a sentinel 0 and element 1.
// Author's open notes, kept from the original: it is unverified whether
// element 1 is always on the stack at the start, and whether the replay
// should begin at operation 1 with operation 0 handled up front.
template <typename T>
void FischerHeunRMQ<T>::computeRMQForType(const vector<StackOperationSeqNode>& op_seq, size_t op_seq_index, vector<vector<size_t>> &LCA)
{
	vector<bool> on_stack(segment_length, false);
	stack<size_t> replay;
	replay.push(0);
	replay.push(1);
	on_stack[0] = true;

	for (size_t step = 0; step <= op_seq_index; ++step)
	{
		const StackOperationSeqNode& op = op_seq[step];
		const size_t node = op.be_operated_index;

		if (op.op_type != StackOperateType::PUSH)
		{
			on_stack[node - 1] = false;
			replay.pop();
			continue;
		}

		on_stack[node - 1] = true;
		const size_t top = replay.top();

		// Rows strictly between the old top and the pushed element map to the
		// pushed element itself.
		for (size_t pos = top + 1; pos < node; ++pos)
		{
			LCA[pos - 1][node - 1] = node;
		}

		// From the old top down to row 1: a row whose element is still on the
		// stack maps to itself; any other row maps to the closest on-stack
		// element seen so far on this downward walk.
		size_t nearest = 0;
		for (size_t pos = top; pos >= 1; --pos)
		{
			if (on_stack[pos - 1])
			{
				nearest = pos;
			}
			LCA[pos - 1][node - 1] = nearest;
		}

		replay.push(node);
	}
}

// Recursively enumerates push/pop sequences over the elements i .. segment_length
// and, each time element segment_length has been pushed, fills the position
// table of the type code reached so far.
//
// i             element pushed by this call (1-based).
// q             second index into ballot_num for the next pop; passed on as
//               q - 1 after a pop.
// work_stack    stack being simulated; restored to its entry state on return.
// op_seq        buffer receiving the operation sequence of the current path.
// op_seq_index  slot of op_seq written by this call.
// N             type code accumulated so far. This parameter shadows the
//               member N; each pop adds ballot_num[segment_length - i][q],
//               the same term computeTypeForEveryBlock adds per pop.
// ballot_num    table filled by computeBallotNumber.
// group_in_query_map_table  per-type tables to fill (shadows the member of
//               the same name).
template<typename T>
void FischerHeunRMQ<T>::enumerateBlockTypePreprocessRMQ(size_t i, size_t q, stack<size_t> &work_stack, vector<StackOperationSeqNode> &op_seq, size_t op_seq_index, size_t N, vector<vector<size_t>> &ballot_num, vector<vector<vector<size_t>>> &group_in_query_map_table)
{  // NOTE: take care with the first-level recursive call
	// Branch 1: push i directly on top of the current stack.
	work_stack.push(i);
	op_seq[op_seq_index] = StackOperationSeqNode(StackOperateType::PUSH, i);
	if (i == segment_length)
	{
		// All elements are placed: op_seq[0 .. op_seq_index] describes type N.
		computeRMQForType(op_seq, op_seq_index, group_in_query_map_table[N]);
	}
	else
	{
		enumerateBlockTypePreprocessRMQ(i + 1, q, work_stack, op_seq, op_seq_index + 1, N, ballot_num, group_in_query_map_table);
	}
	// Undo the push of i.
	work_stack.pop();

	// Branch 2: if something is left on the stack, pop it first and then retry
	// pushing i. The POP record overwrites the PUSH written above in the same slot.
	if (work_stack.empty() == false)
	{
		op_seq[op_seq_index] = StackOperationSeqNode(StackOperateType::POP, work_stack.top());
		work_stack.pop();
		enumerateBlockTypePreprocessRMQ(i, q - 1, work_stack, op_seq, op_seq_index + 1, N + ballot_num[segment_length - i][q], ballot_num, group_in_query_map_table);
		// Restore the popped element so the caller sees the stack unchanged.
		work_stack.push(op_seq[op_seq_index].be_operated_index);
	}
}

// Minimum of the sequence over the closed range [i, j].
//
// i, j    0-based positions in the original sequence, i <= j < N.
// returns the minimum value and its 0-based position in the sequence.
//
// Prints an error and terminates the process with exit(-1) when the range is
// invalid. The check now rejects j == N as well: the previous `j > N` test let
// it through, which read either padding or past the end of the tables.
//
// The range is split into at most three parts: the tail of the first block,
// the run of whole blocks strictly between the first and last block, and the
// head of the last block. The two partial parts use the per-type position
// tables and the middle part uses the sparse table of block minima. When
// candidates from different parts have equal values, the part further left wins.
template <typename T>
m_info<T> FischerHeunRMQ<T>::queryRMQ(size_t i, size_t j)
{
	if (i > j || j >= N)
	{
		cout << "非法的查询索引!" << endl;
		exit(-1);
	}

	const size_t first_block = i / segment_length;
	const size_t first_pos = i % segment_length;
	const size_t last_block = j / segment_length;
	const size_t last_pos = j % segment_length;

	// Minimum of in-block range [from, to] of `block`, with the index already
	// converted to a position in the whole sequence.
	auto in_block_min = [this](size_t block, size_t from, size_t to) -> m_info<T>
	{
		const size_t offset = block * segment_length;
		m_info<T> found = query(from, to, group_in_query_map_table[block_type[block]], offset, value_seq);
		found.index += offset;
		return found;
	};

	if (first_block == last_block)
	{
		return in_block_min(first_block, first_pos, last_pos);
	}

	m_info<T> best = in_block_min(first_block, first_pos, segment_length - 1);

	if (first_block + 1 < last_block)
	{
		// The sparse table reports the block number; turn it into the position
		// of that block's minimum.
		m_info<T> middle = query(first_block + 1, last_block - 1, group_between_query_table);
		const size_t block = middle.index;
		middle.index = block * segment_length + min_in_block_index[block];
		best = min_m_info(best, middle);
	}

	const m_info<T> tail = in_block_min(last_block, 0, last_pos);
	return min_m_info(best, tail);
}


int main()
{
	const int N = 256;
	vector<long long> rmq_seq(N);
	for (size_t i = 0; i < N; ++i)
	{
		rmq_seq[i] = i + 1;
	}
	shuffle(rmq_seq.begin(), rmq_seq.end(), default_random_engine());
	FischerHeunRMQ<long long> obj(rmq_seq, 0xffffffffffffffffull >> 1);
	for (size_t i = 0; i < rmq_seq.size() - 1; ++i)
	{
		for (size_t j = i + 1; j < rmq_seq.size(); ++j)
		{
			m_info<long long> r = obj.queryRMQ(i, j);
			long long m = rmq_seq[i];
			for (size_t k = i + 1; k <= j; ++k)
			{
				if (rmq_seq[k] < m)
				{
					m = rmq_seq[k];
				}
			}
			if (r.m_value == m)
			{
				cout << "[" << i << "," << j << "]查询结果正确!" << endl;
			}
			else
			{
				cout << "[" << i << "," << j << "]查询结果错误!" << endl;
				exit(-1);
			}
		}
	}
	return 0;
}


  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
抱歉,我不确定您指的是哪种RMQ算法。一般来说,RMQ是“区间最小值查询”(Range Minimum Query)的缩写,其实现算法有多种。以下是两种常见的RMQ算法实现代码,供参考:

1. 线段树RMQ算法

```python
class SegmentTree:
    def __init__(self, arr):
        self.tree = [0] * (4 * len(arr))
        self.build(arr, 0, 0, len(arr) - 1)

    def build(self, arr, index, left, right):
        if left == right:
            self.tree[index] = arr[left]
        else:
            mid = (left + right) // 2
            self.build(arr, index * 2 + 1, left, mid)
            self.build(arr, index * 2 + 2, mid + 1, right)
            self.tree[index] = min(self.tree[index * 2 + 1], self.tree[index * 2 + 2])

    def query(self, index, left, right, qleft, qright):
        if left > qright or right < qleft:
            return float('inf')
        elif qleft <= left and qright >= right:
            return self.tree[index]
        else:
            mid = (left + right) // 2
            return min(self.query(index * 2 + 1, left, mid, qleft, qright), self.query(index * 2 + 2, mid + 1, right, qleft, qright))

# 示例
arr = [1, 3, 2, 7, 9, 11]
tree = SegmentTree(arr)
print(tree.query(0, 0, len(arr) - 1, 1, 4))  # 输出2,即arr[1:5]的最小值
```

2. ST算法

```python
import math

class ST:
    def __init__(self, arr):
        n = len(arr)
        k = int(math.log2(n))
        self.table = [[0] * (k + 1) for _ in range(n)]
        for i in range(n):
            self.table[i][0] = arr[i]
        for j in range(1, k + 1):
            for i in range(n - 2 ** j + 1):
                self.table[i][j] = min(self.table[i][j - 1], self.table[i + 2 ** (j - 1)][j - 1])

    def query(self, left, right):
        k = int(math.log2(right - left + 1))
        return min(self.table[left][k], self.table[right - 2 ** k + 1][k])

# 示例
arr = [1, 3, 2, 7, 9, 11]
st = ST(arr)
print(st.query(1, 4))  # 输出2,即arr[1:5]的最小值
```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值