#include "HyperSplit.h"
/// Creates an empty HyperSplit classifier; call ConstructClassifier() to build the tree.
/// @param binth  leaf threshold: ConstructHSTree stops splitting a node that holds
///               at most `binth` rules.
HyperSplit::HyperSplit(uint64_t binth) {
    this->binth = binth;
    // No tree yet: ClassifyAPacket() returns -1 while root is null.
    root = nullptr;
    numRules = 0;
    gChildCount = 0;
    gNumLeafNode = 0;
    gWstDepth = 0;
    gAvgDepth = 0;
    // ConstructHSTree multiplies the per-dimension segment counts into this value,
    // so it must start at 1 (starting at 0 would make the product always 0).
    gNumTotalNonOverlappings = 1;
    gNumTreeNode = 0;
}
/// Builds the HyperSplit decision tree over `rules` and stores it in `root`.
/// @param rules  the rule set to classify against; leaf rule lists are scanned in
///               this order, so the first match in a leaf wins.
void HyperSplit::ConstructClassifier(const vector<Rule> &rules) {
    root = ConstructHSTree(rules, 0);
    // Record the input size only after the build: ConstructHSTree may overwrite
    // numRules while recursing, and MemSizeBytes() relies on this value.
    numRules = int(rules.size());
}
int HyperSplit::ClassifyAPacket(const Packet &packet) {
if(!root) {
return -1;
}
hs_node *node = root;
int Query = 0;
int matchPri = -1;
while(!node->isleaf) {
Query ++;
if(packet[node->d2s] <= node->thresh) {
node = node->child[0];
} else {
node = node->child[1];
}
}
for(auto & rule : node->ruleset) {
Query ++;
if(rule.MatchesPacket(packet)){
matchPri = rule.priority;
break;
}
}
QueryUpdate(Query);
return matchPri;
}
int HyperSplit::ClassifyAPacket(const Packet &packet, uint64_t &Query) {
if(!root) {
return -1;
}
hs_node *node = root;
int matchPri = -1;
while(!node->isleaf) {
Query ++;
if(packet[node->d2s] <= node->thresh) {
node = node->child[0];
} else {
node = node->child[1];
}
}
for(auto & rule : node->ruleset) {
Query ++;
if(rule.MatchesPacket(packet)){
matchPri = rule.priority;
break;
}
}
return matchPri;
}
/// Rule deletion is not implemented: this is a no-op and the tree is left unchanged.
void HyperSplit::DeleteRule(const Rule &rule) {
    (void) rule;  // unused until incremental updates are supported
}
/// Rule insertion is not implemented: this is a no-op and the tree is left unchanged.
void HyperSplit::InsertRule(const Rule &rule) {
    (void) rule;  // unused until incremental updates are supported
}
/// Recursively builds a HyperSplit subtree over `rules`.
///
/// A node becomes a leaf (keeping all of `rules`) when:
///   - it holds at most `binth` rules,
///   - no dimension has at least three spanned segment points, or
///   - the chosen split would leave one child with every rule.
/// Otherwise the tree splits the dimension with the lowest average number of rules
/// per segment. The split point is where the accumulated per-segment rule count
/// first exceeds half of the total. Each child receives copies of the overlapping
/// rules, clipped to the child's sub-range on that dimension.
///
/// @param rules  rules of this node (ranges already clipped by ancestor splits)
/// @param depth  depth of this node; the root is 0
/// @return a newly heap-allocated node
hs_node *HyperSplit::ConstructHSTree(const vector<Rule> &rules, int depth) {
    // Creates a leaf holding all of `rules` and records it in the tree statistics.
    // Every leaf case goes through here so the counters used by MemSizeBytes() agree.
    auto makeLeaf = [&]() -> hs_node * {
        gChildCount++;
        gNumLeafNode++;
        gAvgDepth += depth;
        gWstDepth = max(gWstDepth, uint64_t(depth));
        if (gNumLeafNode % 1000000 == 0) {
            printf(".");  // progress indicator while building very large trees
        }
        return new hs_node(0, depth, 0, true, rules);
    };

    // The root resets the product that the dimension loop below multiplies into,
    // so the statistic describes only the current build.
    if (depth == 0) {
        gNumTotalNonOverlappings = 1;
    }

    // Few enough rules: do not split further.
    if (rules.size() <= binth) {
        return makeLeaf();
    }

    /* generate segments for input filtset */
    unsigned int dim, num, pos;
    unsigned int maxDiffSegPts = 1;  // largest spanned segment-point count over all dimensions
    unsigned int d2s = 0;            // dimension to split (lowest average rules per segment)
    uint64_t thresh = 0;             // split value: the left child covers values <= thresh
    unsigned int range[2][2] = {};   // sub-space ranges of the children on d2s: [child][low/high]
    vector<vector<uint64_t> > segPoints(MAXDIMENSIONS, vector<uint64_t>(rules.size() * 2, 0));
    vector<vector<uint64_t> > segPointsInfo(MAXDIMENSIONS, vector<uint64_t>(rules.size() * 2, 0));
    vector<uint64_t> tempSegPoints(2 * rules.size());
    double hightAvg, hightAll;

    // Collect endpoints: segPoints[dim][2*num] is rule num's low bound on dim and
    // segPoints[dim][2*num+1] is its high bound.
    for (dim = 0; dim < MAXDIMENSIONS; dim++) {
        for (num = 0; num < rules.size(); num++) {
            segPoints[dim][2 * num] = rules[num].range[dim][LowDim];
            segPoints[dim][2 * num + 1] = rules[num].range[dim][HighDim];
        }
    }
    // Sort the segment points of every dimension.
    for (dim = 0; dim < MAXDIMENSIONS; dim++) {
        sort(segPoints[dim].begin(), segPoints[dim].end());
    }

    /* Compress the Segment Points, and select the dimension to split (d2s) */
    hightAvg = double(2 * rules.size() + 1);
    for (dim = 0; dim < MAXDIMENSIONS; dim++) {
        vector<uint64_t> hightList;
        uint64_t diffSegPts = 1;
        // Deduplicate the sorted endpoints into tempSegPoints;
        // diffSegPts ends up as the number of distinct values.
        tempSegPoints[0] = segPoints[dim][0];
        for (num = 1; num < 2 * rules.size(); num++) {
            if (segPoints[dim][num] != tempSegPoints[diffSegPts - 1]) {
                tempSegPoints[diffSegPts] = segPoints[dim][num];
                diffSegPts++;
            }
        }

        /* Span the segment points which is both start and end of some rules */
        // Rebuild segPoints[dim] in place (pos is the write index). Each distinct value
        // is tagged in segPointsInfo: 0 = start (low bound) of some rule,
        // 1 = end (high bound). A value that is both is emitted twice (0, then 1).
        pos = 0;
        for (num = 0; num < diffSegPts; num++) {
            int ifStart = 0, ifEnd = 0;
            segPoints[dim][pos] = tempSegPoints[num];
            // Is this value the low bound of any rule?
            for (const auto &rule : rules) {
                if (rule.range[dim][LowDim] == tempSegPoints[num]) {
                    ifStart = 1;
                    break;
                }
            }
            // Is this value the high bound of any rule?
            for (const auto &rule : rules) {
                if (rule.range[dim][HighDim] == tempSegPoints[num]) {
                    ifEnd = 1;
                    break;
                }
            }
            if (ifStart && ifEnd) {
                segPointsInfo[dim][pos] = 0;
                pos++;
                segPoints[dim][pos] = tempSegPoints[num];
                segPointsInfo[dim][pos] = 1;
                pos++;
            } else if (ifStart) {
                segPointsInfo[dim][pos] = 0;
                pos++;
            } else {
                segPointsInfo[dim][pos] = 1;
                pos++;
            }
        }

        /* now pos is the total number of points in the spanned segment point list */
        if (depth == 0) {
            gNumNonOverLappings[dim] = pos;
            gNumTotalNonOverlappings *= (uint64_t) pos;
        }

        if (pos >= 3) {
            hightAll = 0;
            hightList.resize(pos);
            // hightList[i] = number of rules fully covering segment [segPoints[i], segPoints[i+1]].
            for (unsigned int i = 0; i < pos - 1; i++) {
                hightList[i] = 0;
                for (const auto &rule : rules) {
                    if (rule.range[dim][LowDim] <= segPoints[dim][i] &&
                        rule.range[dim][HighDim] >= segPoints[dim][i + 1]) {
                        hightList[i]++;
                        hightAll++;
                    }
                }
            }
            // Prefer the dimension with the lowest average number of rules per segment.
            if (hightAvg > hightAll / (pos - 1)) {
                double hightSum = 0;
                /* select current dimension */
                d2s = dim;
                hightAvg = hightAll / (pos - 1);
                /* the first segment MUST belong to the left child */
                hightSum += double(hightList[0]);
                for (num = 1; num < pos - 1; num++) {
                    // A start point (info 0) belongs to the right child, so split just
                    // below it. An end point (info 1) stays on the left.
                    if (segPointsInfo[d2s][num] == 0) {
                        thresh = segPoints[d2s][num] - 1;
                    } else {
                        thresh = segPoints[d2s][num];
                    }
                    // Stop once the left side already holds more than half of the weight.
                    if (hightSum > hightAll / 2) {
                        break;
                    }
                    hightSum += double(hightList[num]);
                }
                range[0][0] = segPoints[d2s][0];
                range[0][1] = thresh;
                range[1][0] = thresh + 1;
                range[1][1] = segPoints[d2s][pos - 1];
            }
        } // pos >= 3
        if (maxDiffSegPts < pos) {
            maxDiffSegPts = pos;
        }
    }

    // No dimension has more than one segment: nothing to split on.
    if (maxDiffSegPts <= 2) {
        return makeLeaf();
    }

    // Distribute the rules to the children, clipping each copy to the child's range on d2s.
    vector<Rule> leftRule, rightRule;
    for (const auto &rule : rules) {
        // left
        if (rule.range[d2s][LowDim] <= range[0][HighDim] &&
            rule.range[d2s][HighDim] >= range[0][LowDim]) {
            Rule r = rule;
            r.range[d2s][LowDim] = max(r.range[d2s][LowDim], range[0][LowDim]);
            r.range[d2s][HighDim] = min(r.range[d2s][HighDim], range[0][HighDim]);
            leftRule.push_back(r);
        }
        // right
        if (rule.range[d2s][LowDim] <= range[1][HighDim] &&
            rule.range[d2s][HighDim] >= range[1][LowDim]) {
            Rule r = rule;
            r.range[d2s][LowDim] = max(r.range[d2s][LowDim], range[1][LowDim]);
            r.range[d2s][HighDim] = min(r.range[d2s][HighDim], range[1][HighDim]);
            rightRule.push_back(r);
        }
    }

    // If one child would receive every rule, the split does not reduce the rule count.
    // Keep the rules here as a properly counted leaf instead of an internal node.
    if (leftRule.size() == rules.size() || rightRule.size() == rules.size()) {
        return makeLeaf();
    }

    auto *node = new hs_node(d2s, depth, thresh, false, rules);
    gNumTreeNode++;
    node->child[0] = ConstructHSTree(leftRule, depth + 1);
    node->child[1] = ConstructHSTree(rightRule, depth + 1);
    return node;
}
/// Estimated memory footprint in bytes, computed as the sum of three parts:
///   - internal nodes: gNumTreeNode * TREE_NODE_SIZE
///   - leaves: gNumLeafNode * LEAF_NODE_SIZE
///   - rules: numRules * PTR_SIZE (presumably one rule pointer per rule)
Memory HyperSplit::MemSizeBytes() const {
    const auto innerBytes = gNumTreeNode * TREE_NODE_SIZE;
    const auto leafBytes = gNumLeafNode * LEAF_NODE_SIZE;
    const auto ruleBytes = numRules * PTR_SIZE;
    return innerBytes + leafBytes + ruleBytes;
}
// HyperSplit
// First published: 2022-03-26 15:33:54