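// This program rewrites a "GROUP BY CUBE" query into an equivalent UNION ALL of
// GROUP BY queries that run over a CTE holding the result of one GROUP BY on all
// CUBE columns.
// (Original note: 本程序将group by cube查询语句转为等价的对单次group by的结果CTE执行group by 查询的union all语句。)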
#include <algorithm>
#include <cctype>
#include <iostream>
#include <limits>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// Helper: returns a copy of `str` without leading and trailing spaces and tabs.
// Other whitespace (e.g. '\n') is left in place. An all-blank or empty input
// yields an empty string.
string trim(const string& str) {
    const auto is_blank = [](char ch) { return ch == ' ' || ch == '\t'; };
    const auto head = find_if_not(str.begin(), str.end(), is_blank);
    const auto tail = find_if_not(str.rbegin(), str.rend(), is_blank).base();
    return head < tail ? string(head, tail) : string();
}
// Preprocessing: guarantees that every ')' is followed by a space, so that
// "count(v1)c1" becomes "count(v1) c1". A ')' that is already followed by ' '
// is left alone; a ')' at the very end of the input gets a trailing space.
string preprocess_input(const string& input) {
    string result;
    for (size_t idx = 0; idx < input.size(); ++idx) {
        const char ch = input[idx];
        result += ch;
        const bool at_end = idx + 1 == input.size();
        if (ch == ')' && (at_end || input[idx + 1] != ' ')) {
            result += ' ';
        }
    }
    return result;
}
// Extracts the column list of a "group by cube(...)" clause.
//
// Scans from the first '(' to its matching ')' and splits on commas that are not
// nested inside parentheses, so expressions such as "date(x)" stay intact. All
// whitespace is removed from each column and empty entries are skipped.
// Returns an empty vector when there is no '(' or when its matching ')' is missing.
vector<string> extract_cube_columns(const string& cube_part) {
    vector<string> columns;
    const size_t open = cube_part.find('(');
    // Test find() before doing arithmetic on it: npos + 1 wraps to 0, which made
    // the former "start == npos" check unreachable.
    if (open == string::npos) {
        return columns;
    }
    string current;
    const auto flush = [&columns, &current]() {
        if (!current.empty()) {
            columns.push_back(current);
        }
        current.clear();
    };
    int depth = 0;
    for (size_t i = open + 1; i < cube_part.size(); ++i) {
        const unsigned char ch = static_cast<unsigned char>(cube_part[i]);
        if (depth == 0 && ch == ')') {  // end of the CUBE list
            flush();
            return columns;
        }
        if (depth == 0 && ch == ',') {  // top-level separator
            flush();
            continue;
        }
        if (ch == '(') {
            ++depth;
        } else if (ch == ')') {
            --depth;
        }
        // Passing unsigned char avoids undefined behavior of isspace() on non-ASCII bytes.
        if (isspace(ch) == 0) {
            current += static_cast<char>(ch);
        }
    }
    return {};  // no matching ')': malformed clause
}
// Enumerates the GROUP BY levels produced by CUBE over `n` columns, except the
// finest level (all columns kept), which the caller emits separately.
//
// Element [i][j] is true when column j is kept at level i. Level i is taken from
// the bits of counter i (bit j -> column j), so the grand total (all false) comes
// first. For n == 0 the result is empty.
//
// Throws std::invalid_argument if n < 0 and std::length_error if 2^n does not fit
// in an int; the former unchecked `1 << n` was undefined behavior in both cases.
vector<vector<bool>> generate_cube_combinations(int n) {
    if (n < 0) {
        throw invalid_argument("generate_cube_combinations: negative column count");
    }
    if (n >= numeric_limits<int>::digits) {
        throw length_error("generate_cube_combinations: too many CUBE columns");
    }
    const int total = 1 << n;
    vector<vector<bool>> combinations;
    combinations.reserve(static_cast<size_t>(total - 1));
    // The all-true mask is total - 1, the last counter value, so stopping one
    // early excludes exactly the finest level (GROUP BY on every column).
    for (int mask = 0; mask < total - 1; ++mask) {
        vector<bool> current(static_cast<size_t>(n));
        for (int j = 0; j < n; ++j) {
            current[static_cast<size_t>(j)] = ((mask >> j) & 1) != 0;
        }
        combinations.push_back(std::move(current));
    }
    return combinations;
}
// Characters allowed in an unquoted alias: ASCII letters/digits, '_', '$', and any
// non-ASCII byte (so UTF-8 names, e.g. Chinese aliases, are accepted).
static bool alias_is_ident_char(unsigned char ch) {
    return isalnum(ch) != 0 || ch == '_' || ch == '$' || ch >= 0x80;
}

// True if `token` is syntactically usable as an alias: a quoted identifier
// ("..." or `...`) or an unquoted identifier that does not start with a digit.
static bool alias_is_valid(const string& token) {
    if (token.size() >= 2 && (token.front() == '"' || token.front() == '`') &&
        token.back() == token.front()) {
        return true;
    }
    if (token.empty() || isdigit(static_cast<unsigned char>(token.front())) != 0) {
        return false;
    }
    return all_of(token.begin(), token.end(),
                  [](char ch) { return alias_is_ident_char(static_cast<unsigned char>(ch)); });
}

// ASCII-lowercased copy of `s`.
static string alias_lower(string s) {
    transform(s.begin(), s.end(), s.begin(),
              [](unsigned char ch) { return static_cast<char>(tolower(ch)); });
    return s;
}

// Splits one select-list item into (expression, alias).
//
// Recognized forms: "expr alias", "expr AS alias", "func(x)alias" (no space after
// ')') and a bare "expr". The alias is the last token that starts after whitespace
// or a ')' lying outside parentheses and quotes, so "sum(a + b)" is no longer split
// inside the call. A trailing token is NOT taken as an alias when it is not an
// identifier ("f(x)::int"), is an SQL keyword ("a is null"), or follows an operator
// ("a + b", "x like y").
// When no alias is found, the alias equals the whole (trimmed) expression.
pair<string, string> extract_column_with_alias(const string& col_expr) {
    // Words that are never an alias.
    static const vector<string> kNonAliasWords = {
        "and", "as", "asc", "between", "case", "desc", "else", "end", "false", "ilike",
        "in", "is", "like", "not", "null", "or", "then", "true", "when"};
    // Words after which the expression is still incomplete.
    static const vector<string> kOperatorWords = {
        "and", "between", "case", "distinct", "else", "ilike", "in", "is", "like",
        "not", "or", "then", "when"};

    const auto is_space = [](char ch) { return isspace(static_cast<unsigned char>(ch)) != 0; };
    const auto strip = [&is_space](const string& s) {
        size_t begin = 0;
        size_t end = s.size();
        while (begin < end && is_space(s[begin])) ++begin;
        while (end > begin && is_space(s[end - 1])) --end;
        return s.substr(begin, end - begin);
    };

    const string text = strip(col_expr);
    const pair<string, string> no_alias = make_pair(text, text);

    // Find where a trailing alias would start: just after the last whitespace or
    // closing parenthesis that lies outside parentheses and quotes.
    size_t split = string::npos;
    int depth = 0;
    char quote = '\0';
    for (size_t i = 0; i < text.size(); ++i) {
        const char ch = text[i];
        if (quote != '\0') {
            if (ch == quote) quote = '\0';
        } else if (ch == '\'' || ch == '"' || ch == '`') {
            quote = ch;
        } else if (ch == '(') {
            ++depth;
        } else if (ch == ')') {
            if (depth > 0) --depth;
            if (depth == 0) split = i + 1;
        } else if (depth == 0 && is_space(ch)) {
            split = i + 1;
        }
    }
    if (split == string::npos || split >= text.size()) {
        return no_alias;
    }

    const string alias = strip(text.substr(split));
    const string expr = strip(text.substr(0, split));
    if (alias.empty() || expr.empty()) {
        return no_alias;
    }

    // Last whitespace-separated word of the expression part.
    size_t word_begin = expr.size();
    while (word_begin > 0 && !is_space(expr[word_begin - 1])) --word_begin;
    const string last_word = alias_lower(expr.substr(word_begin));

    // Explicit "expr AS alias": drop the keyword from the expression.
    if (last_word == "as" && word_begin > 0) {
        return make_pair(strip(expr.substr(0, word_begin)), alias);
    }

    // Implicit "expr alias": reject tokens that are really the tail of the expression.
    const string alias_lc = alias_lower(alias);
    const bool alias_is_keyword =
        find(kNonAliasWords.begin(), kNonAliasWords.end(), alias_lc) != kNonAliasWords.end();
    const bool expr_is_open =
        string("+-*/%|&=<>,~^!:").find(expr.back()) != string::npos ||
        find(kOperatorWords.begin(), kOperatorWords.end(), last_word) != kOperatorWords.end();
    if (!alias_is_valid(alias) || alias_is_keyword || expr_is_open) {
        return no_alias;
    }
    return make_pair(expr, alias);
}
// ---- Helpers private to rewrite_cube_with_cte ----

// Returns `s` without leading/trailing whitespace of any kind (including newlines).
static string cube_strip(const string& s) {
    size_t begin = 0;
    size_t end = s.size();
    while (begin < end && isspace(static_cast<unsigned char>(s[begin])) != 0) ++begin;
    while (end > begin && isspace(static_cast<unsigned char>(s[end - 1])) != 0) --end;
    return s.substr(begin, end - begin);
}

// Returns an ASCII-lowercased copy of `s`.
static string cube_lower(string s) {
    transform(s.begin(), s.end(), s.begin(),
              [](unsigned char ch) { return static_cast<char>(tolower(ch)); });
    return s;
}

// Canonical form for comparing a select expression with a CUBE column:
// whitespace removed and ASCII-lowercased.
static string cube_normalize(const string& s) {
    string out;
    for (const char ch : s) {
        const unsigned char uch = static_cast<unsigned char>(ch);
        if (isspace(uch) == 0) out += static_cast<char>(tolower(uch));
    }
    return out;
}

// Removes a dangling " AS" keyword that a whitespace-based alias split may leave
// at the end of an expression ("a as" -> "a").
static string cube_strip_trailing_as(const string& expr) {
    const string e = cube_strip(expr);
    if (e.size() > 3 && isspace(static_cast<unsigned char>(e[e.size() - 3])) != 0 &&
        cube_lower(e.substr(e.size() - 2)) == "as") {
        return cube_strip(e.substr(0, e.size() - 3));
    }
    return e;
}

// Splits a select list on commas outside parentheses and quotes, so that
// "coalesce(a,b) x, y" yields two items rather than three.
static vector<string> cube_split_select_list(const string& list) {
    vector<string> items;
    string current;
    int depth = 0;
    char quote = '\0';
    for (const char ch : list) {
        if (quote != '\0') {
            if (ch == quote) quote = '\0';
        } else if (ch == '\'' || ch == '"' || ch == '`') {
            quote = ch;
        } else if (ch == '(') {
            ++depth;
        } else if (ch == ')') {
            if (depth > 0) --depth;
        } else if (ch == ',' && depth == 0) {
            items.push_back(current);
            current.clear();
            continue;
        }
        current += ch;
    }
    items.push_back(current);
    return items;
}

// Parses an expression of the exact form "name(args)". On success stores the
// lowercased function name and the trimmed argument text and returns true;
// "sum(x) * 2" or unbalanced parentheses return false.
static bool cube_parse_call(const string& expr, string& name, string& args) {
    const size_t open = expr.find('(');
    if (open == string::npos) return false;
    const string fn = cube_strip(expr.substr(0, open));
    if (fn.empty()) return false;
    for (const char ch : fn) {
        const unsigned char uch = static_cast<unsigned char>(ch);
        if (isalnum(uch) == 0 && uch != '_') return false;
    }
    int depth = 0;
    for (size_t i = open; i < expr.size(); ++i) {
        if (expr[i] == '(') {
            ++depth;
        } else if (expr[i] == ')' && --depth == 0) {
            if (!cube_strip(expr.substr(i + 1)).empty()) return false;
            name = cube_lower(fn);
            args = cube_strip(expr.substr(open + 1, i - open - 1));
            return true;
        }
    }
    return false;
}

// One parsed item of the original select list.
struct CubeSelectItem {
    string expr;            // expression without its alias
    string alias;           // name under which the CTE t0 exposes the value
    bool is_group = false;  // true if expr is one of the CUBE columns
    size_t cube_index = 0;  // position in the CUBE list (group items only)
    string func;            // lowercased aggregate name (aggregate items only)
    string arg;             // aggregate argument text (aggregate items only)
};

// SQL expression that combines the per-finest-group partial results stored in t0
// into one value for a coarser level. `index` is the item's select-list position,
// used to name AVG's hidden helper columns.
static string cube_reaggregate(const CubeSelectItem& item, size_t index) {
    if (item.func == "avg") {
        // An average of averages is wrong; divide the summed partial SUMs by the
        // summed partial COUNTs (NULL when there are no non-null inputs, like AVG).
        const string idx = to_string(index);
        return "sum(cube_avg_sum_" + idx + ")/nullif(sum(cube_avg_cnt_" + idx + "),0)";
    }
    // Partial counts and sums add up; min/max of partial min/max is exact.
    const string outer = (item.func == "min" || item.func == "max") ? item.func : string("sum");
    return outer + "(" + item.alias + ")";
}

// Rewrites "select ... from ... group by cube(c1,...,cn)" into
//   with t0 as materialized (select ... from ... group by c1,...,cn)
//   select ... from t0                            -- finest level
//   union all select ... from t0 [group by ...]   -- one per proper subset of c1..cn
// ("as materialized" is PostgreSQL 12+ CTE syntax).
//
// Returns true and stores the SQL in `output` on success. Returns false, leaving
// `output` untouched, when the query is outside the supported shape: no CUBE
// clause, text after the CUBE clause (HAVING/ORDER BY/...), a select item that is
// neither a CUBE column nor an aliased count/sum/min/max/avg call, a DISTINCT
// aggregate, or a CUBE column missing from the select list.
bool rewrite_cube_with_cte(const string& input, string& output) {
    const string processed_input = preprocess_input(input);

    const regex cube_regex(R"(group\s+by\s+cube\s*\(\s*.+?\s*\))", regex_constants::icase);
    smatch match;
    if (!regex_search(processed_input, match, cube_regex)) {
        return false;
    }
    const size_t group_pos = static_cast<size_t>(match.position(0));
    const size_t group_end = group_pos + static_cast<size_t>(match.length(0));
    const vector<string> cube_columns = extract_cube_columns(match[0].str());
    if (cube_columns.empty()) {
        return false;
    }

    // Clauses after CUBE(...) cannot simply be carried over; refuse rather than
    // silently dropping them as before.
    string tail = processed_input.substr(group_end);
    tail.erase(remove(tail.begin(), tail.end(), ';'), tail.end());
    if (!cube_strip(tail).empty()) {
        return false;
    }

    // SELECT / FROM as whole, case-insensitive keywords: find("select") missed
    // "SELECT" and find("from") matched inside identifiers such as "from_date".
    const regex select_regex(R"(\bselect\b)", regex_constants::icase);
    const regex from_regex(R"(\bfrom\b)", regex_constants::icase);
    smatch select_match;
    if (!regex_search(processed_input, select_match, select_regex)) {
        return false;
    }
    const size_t select_pos = static_cast<size_t>(select_match.position(0));
    const size_t list_begin = select_pos + static_cast<size_t>(select_match.length(0));
    if (list_begin > group_pos) {
        return false;
    }
    // The FROM that ends the select list is the first one outside parentheses
    // (skips e.g. "extract(year from d)").
    const auto list_it = processed_input.cbegin() + static_cast<string::difference_type>(list_begin);
    const auto group_it = processed_input.cbegin() + static_cast<string::difference_type>(group_pos);
    size_t from_pos = string::npos;
    for (sregex_iterator it(list_it, group_it, from_regex), end; it != end; ++it) {
        const size_t pos = list_begin + static_cast<size_t>(it->position(0));
        const string before = processed_input.substr(list_begin, pos - list_begin);
        if (count(before.begin(), before.end(), '(') == count(before.begin(), before.end(), ')')) {
            from_pos = pos;
            break;
        }
    }
    if (from_pos == string::npos) {
        return false;
    }

    const string before_select = processed_input.substr(0, select_pos);
    const string select_list = processed_input.substr(list_begin, from_pos - list_begin);
    const string from_part = processed_input.substr(from_pos, group_pos - from_pos);

    // Classify each select item as a CUBE (group) column or a re-aggregatable aggregate.
    vector<CubeSelectItem> items;
    vector<string> cube_alias(cube_columns.size());  // t0 column name of each CUBE column
    for (const string& raw : cube_split_select_list(select_list)) {
        const string text = cube_strip(raw);
        if (text.empty()) {
            return false;
        }
        const pair<string, string> parsed = extract_column_with_alias(text);
        const bool has_alias = parsed.second != parsed.first;
        CubeSelectItem item;
        item.expr = cube_strip_trailing_as(parsed.first);
        item.alias = cube_strip(parsed.second);
        if (item.expr.empty() || item.alias.empty()) {
            return false;
        }

        const string key = cube_normalize(item.expr);
        const auto hit = find_if(cube_columns.begin(), cube_columns.end(),
                                 [&key](const string& c) { return cube_normalize(c) == key; });
        if (hit != cube_columns.end()) {
            item.is_group = true;
            item.cube_index = static_cast<size_t>(hit - cube_columns.begin());
            if (!has_alias) {
                // "t.g1" is exposed by t0 under the name "g1".
                const size_t dot = item.expr.find_last_of('.');
                if (dot != string::npos) item.alias = item.expr.substr(dot + 1);
            }
            if (cube_alias[item.cube_index].empty()) {
                cube_alias[item.cube_index] = item.alias;
            }
        } else {
            // Aggregates need an explicit alias so the outer queries can reference them.
            if (!has_alias || !cube_parse_call(item.expr, item.func, item.arg)) {
                return false;
            }
            const string arg_lc = cube_lower(item.arg);
            const bool is_distinct =
                arg_lc.compare(0, 8, "distinct") == 0 &&
                (arg_lc.size() == 8 ||
                 (isalnum(static_cast<unsigned char>(arg_lc[8])) == 0 && arg_lc[8] != '_'));
            if (is_distinct) {
                return false;  // DISTINCT aggregates cannot be combined from partial results
            }
            static const vector<string> supported = {"count", "sum", "min", "max", "avg"};
            if (find(supported.begin(), supported.end(), item.func) == supported.end()) {
                return false;
            }
        }
        items.push_back(item);
    }
    // Every CUBE column must be a column of t0 for the outer GROUP BYs.
    if (any_of(cube_alias.begin(), cube_alias.end(), [](const string& a) { return a.empty(); })) {
        return false;
    }

    const vector<vector<bool>> combinations =
        generate_cube_combinations(static_cast<int>(cube_columns.size()));

    stringstream rewritten_query;
    // 1. CTE: one GROUP BY over all CUBE columns.
    rewritten_query << "with t0 as materialized \n(" << before_select << "select ";
    for (size_t i = 0; i < items.size(); ++i) {
        if (i != 0) rewritten_query << ",";
        rewritten_query << items[i].expr;
        if (items[i].alias != items[i].expr) rewritten_query << " " << items[i].alias;
    }
    // AVG keeps its partial SUM and COUNT as hidden columns for the coarser levels.
    for (size_t i = 0; i < items.size(); ++i) {
        if (items[i].func == "avg") {
            rewritten_query << ",sum(" << items[i].arg << ") cube_avg_sum_" << i
                            << ",count(" << items[i].arg << ") cube_avg_cnt_" << i;
        }
    }
    rewritten_query << " " << from_part << " group by ";
    for (size_t j = 0; j < cube_columns.size(); ++j) {
        if (j != 0) rewritten_query << ",";
        rewritten_query << cube_columns[j];
    }
    rewritten_query << ")\n";

    // 2. Finest level: the visible columns of t0 as they are.
    rewritten_query << "select ";
    for (size_t i = 0; i < items.size(); ++i) {
        if (i != 0) rewritten_query << ",";
        rewritten_query << items[i].alias;
    }
    rewritten_query << " from t0\n";

    // 3. One coarser level per proper subset of the CUBE columns; dropped columns
    //    become NULL, in the original select-list order.
    for (const vector<bool>& keep : combinations) {
        rewritten_query << "union all\nselect ";
        for (size_t i = 0; i < items.size(); ++i) {
            if (i != 0) rewritten_query << ",";
            const CubeSelectItem& item = items[i];
            if (item.is_group) {
                rewritten_query << (keep[item.cube_index] ? item.alias : string("null"));
            } else {
                rewritten_query << cube_reaggregate(item, i) << " " << item.alias;
            }
        }
        rewritten_query << " from t0";
        // Group only by the kept columns (no clause for the grand total) instead
        // of the non-portable "group by null".
        string group_list;
        for (size_t j = 0; j < cube_columns.size(); ++j) {
            if (!keep[j]) continue;
            if (!group_list.empty()) group_list += ",";
            group_list += cube_alias[j];
        }
        if (!group_list.empty()) rewritten_query << " group by " << group_list;
        rewritten_query << "\n";
    }
    output = rewritten_query.str();
    return true;
}
// Demo: rewrites a sample CUBE query and prints the input and the generated SQL,
// or prints "error" when the rewrite is rejected.
void test() {
    const string input = "select g1,g2,v2, count(v1)c1,sum(v1)s1,avg(v1)a1 from t group by cube(g1,g2, v2)";
    string actual;
    const bool ok = rewrite_cube_with_cte(input, actual);
    if (!ok) {
        cout << "error" << endl;
        return;
    }
    cout << "input: " << input << endl;
    cout << "output:\n" << actual << endl;
}
// Program entry point: runs the CUBE rewrite demo. Flowing off the end of main
// returns 0.
int main() {
    test();
}