hypersplit


#include "HyperSplit.h"

/**
 * Creates an empty classifier; no tree exists until ConstructClassifier runs.
 *
 * @param binth Leaf threshold: a rule set with at most this many rules is
 *              stored as a leaf instead of being split further.
 */
HyperSplit::HyperSplit(uint64_t binth) {
    this->binth = binth;

    // No tree yet; ClassifyAPacket tests this pointer before descending.
    root = nullptr;
    numRules = 0;

    // Build statistics, filled in by ConstructHSTree.
    gChildCount = 0;
    gNumTreeNode = 0;
    gNumLeafNode = 0;
    gWstDepth = 0;
    gAvgDepth = 0;

    // This counter is only ever multiplied (once per dimension at the root),
    // so it must start at the multiplicative identity; starting at 0 would
    // pin the product to 0 forever.
    gNumTotalNonOverlappings = 1;
}

/**
 * Builds the HyperSplit tree for the given rule set and stores it in `root`.
 *
 * @param rules Complete rule set to classify against.
 */
void HyperSplit::ConstructClassifier(const vector<Rule> &rules) {
    root = ConstructHSTree(rules, 0);
    // Assigned after the build so the value always reflects the full rule
    // set, regardless of what tree construction writes to numRules while it
    // recurses over sub-rulesets. MemSizeBytes relies on this count.
    numRules = int(rules.size());
}


int HyperSplit::ClassifyAPacket(const Packet &packet) {
    if(!root) {
        return -1;
    }
    hs_node *node = root;
    int Query = 0;
    int matchPri = -1;

    while(!node->isleaf) {
        Query ++;
        if(packet[node->d2s] <= node->thresh) {
            node = node->child[0];
        } else {
            node = node->child[1];
        }
    }
    for(auto & rule : node->ruleset) {
        Query ++;
        if(rule.MatchesPacket(packet)){
            matchPri = rule.priority;
            break;
        }
    }
    QueryUpdate(Query);
    return matchPri;
}

int HyperSplit::ClassifyAPacket(const Packet &packet, uint64_t &Query) {
    if(!root) {
        return -1;
    }
    hs_node *node = root;
    int matchPri = -1;

    while(!node->isleaf) {
        Query ++;
        if(packet[node->d2s] <= node->thresh) {
            node = node->child[0];
        } else {
            node = node->child[1];
        }
    }
    for(auto & rule : node->ruleset) {
        Query ++;
        if(rule.MatchesPacket(packet)){
            matchPri = rule.priority;
            break;
        }
    }
    return matchPri;
}


/**
 * Rule deletion is not implemented here: the call does nothing and the tree
 * is left exactly as it was.
 *
 * @param rule Ignored.
 */
void HyperSplit::DeleteRule(const Rule &rule) {
    static_cast<void>(rule);  // deliberately unused
}

/**
 * Rule insertion is not implemented here: the call does nothing and the tree
 * is left exactly as it was.
 *
 * @param rule Ignored.
 */
void HyperSplit::InsertRule(const Rule &rule) {
    static_cast<void>(rule);  // deliberately unused
}

/**
 * Recursively builds the HyperSplit tree for `rules`.
 *
 * For every dimension the rule end points are sorted and de-duplicated into a
 * list of segment points; a point that is both the start of one rule and the
 * end of another is listed twice (once as a start, once as an end). The
 * dimension whose segments are covered by the fewest rules on average is
 * split, at the segment boundary where the accumulated coverage first exceeds
 * half of the total. Rules are clipped to each half and the halves are built
 * recursively.
 *
 * A leaf is produced when
 *   - the rule set has at most `binth` rules,
 *   - no dimension has more than two segment points, or
 *   - the chosen split leaves every rule in one of the halves, so recursing
 *     could not make progress.
 *
 * Side effects: updates the build statistics (gChildCount, gNumLeafNode,
 * gNumTreeNode, gAvgDepth, gWstDepth) and, for the root call only, numRules,
 * gNumNonOverLappings and gNumTotalNonOverlappings.
 *
 * @param rules Rules that intersect the sub-space of this node.
 * @param depth Depth of this node; the root is built with depth 0.
 * @return Newly allocated node; the caller takes ownership.
 */
hs_node *HyperSplit::ConstructHSTree(const vector<Rule> &rules, int depth) {
    // Only the root call describes the whole classifier. Recursive calls must
    // not overwrite the count with the size of a sub-ruleset.
    if (depth == 0) {
        numRules = (int) rules.size();
    }

    // Statistics shared by every path that ends in a leaf.
    auto recordLeaf = [this, depth]() {
        gChildCount++;
        gNumLeafNode++;
        gAvgDepth += depth;
        gWstDepth = max(gWstDepth, uint64_t(depth));
    };

    // Small enough: store the rules in a leaf.
    if (rules.size() <= binth) {
        recordLeaf();
        return new hs_node(0, depth, 0, true, rules);
    }

    const auto numEndPoints = 2 * rules.size();

    unsigned int maxDiffSegPts = 1;                 /* largest segment-point count seen in any dimension */
    unsigned int d2s = 0;                           /* dimension to split */
    uint64_t thresh = 0;                            /* split threshold in d2s */
    unsigned int range[2][2] = {{0, 0}, {0, 0}};    /* sub-space of the left [0] and right [1] child in d2s */

    /* Lowest average coverage found so far. Every achievable average is at
     * most rules.size(), so the first dimension with >= 3 points is always
     * accepted. */
    double hightAvg = double(numEndPoints + 1);

    for (unsigned int dim = 0; dim < MAXDIMENSIONS; dim++) {
        /* Collect and sort every rule end point in this dimension. */
        vector<uint64_t> endPoints(numEndPoints);
        for (unsigned int num = 0; num < rules.size(); num++) {
            endPoints[2 * num] = rules[num].range[dim][LowDim];
            endPoints[2 * num + 1] = rules[num].range[dim][HighDim];
        }
        sort(endPoints.begin(), endPoints.end());

        /* Remove duplicates. */
        vector<uint64_t> distinctPoints(endPoints);
        distinctPoints.erase(std::unique(distinctPoints.begin(), distinctPoints.end()),
                             distinctPoints.end());

        /* Build the spanned list. segPointsInfo is 0 for a point that starts
         * a rule and 1 for a point that ends one; a point that does both is
         * emitted twice, start first. The list never exceeds numEndPoints
         * entries because such a point needs at least two end points. */
        vector<uint64_t> segPoints;
        vector<uint64_t> segPointsInfo;
        segPoints.reserve(numEndPoints);
        segPointsInfo.reserve(numEndPoints);
        for (const uint64_t point : distinctPoints) {
            const bool isStart = std::any_of(rules.begin(), rules.end(), [&](const Rule &rule) {
                return rule.range[dim][LowDim] == point;
            });
            const bool isEnd = std::any_of(rules.begin(), rules.end(), [&](const Rule &rule) {
                return rule.range[dim][HighDim] == point;
            });
            if (isStart && isEnd) {
                segPoints.push_back(point);
                segPointsInfo.push_back(0);
                segPoints.push_back(point);
                segPointsInfo.push_back(1);
            } else {
                /* Every point comes from some rule, so "not a start" means "an end". */
                segPoints.push_back(point);
                segPointsInfo.push_back(isStart ? 0 : 1);
            }
        }

        /* pos is the total number of points in the spanned list. */
        const unsigned int pos = static_cast<unsigned int>(segPoints.size());

        if (depth == 0) {
            /* The total is a product over all dimensions, so it has to start
             * from 1 rather than from whatever value it held before. */
            if (dim == 0) {
                gNumTotalNonOverlappings = 1;
            }
            gNumNonOverLappings[dim] = pos;
            gNumTotalNonOverlappings *= (uint64_t) pos;
        }

        if (pos >= 3) {
            /* hightList[i] = number of rules covering segment [point i, point i+1]. */
            vector<uint64_t> hightList(pos - 1, 0);
            double hightAll = 0;
            for (unsigned int i = 0; i + 1 < pos; i++) {
                for (const auto &rule : rules) {
                    if (rule.range[dim][LowDim] <= segPoints[i] &&
                        rule.range[dim][HighDim] >= segPoints[i + 1]) {
                        hightList[i]++;
                        hightAll++;
                    }
                }
            }

            /* Prefer the dimension with the lowest average coverage. */
            const double avgHere = hightAll / (pos - 1);
            if (hightAvg > avgHere) {
                d2s = dim;
                hightAvg = avgHere;

                /* The first segment MUST belong to the left child. Walk the
                 * remaining boundaries until the left side holds more than
                 * half of the total coverage. */
                double hightSum = double(hightList[0]);
                for (unsigned int num = 1; num < pos - 1; num++) {
                    /* A start point belongs to the right side, so cut just
                     * before it; an end point belongs to the left side. */
                    if (segPointsInfo[num] == 0) {
                        thresh = segPoints[num] - 1;
                    } else {
                        thresh = segPoints[num];
                    }
                    if (hightSum > hightAll / 2) {
                        break;
                    }
                    hightSum += double(hightList[num]);
                }

                range[0][LowDim] = static_cast<unsigned int>(segPoints[0]);
                range[0][HighDim] = static_cast<unsigned int>(thresh);
                range[1][LowDim] = static_cast<unsigned int>(thresh + 1);
                range[1][HighDim] = static_cast<unsigned int>(segPoints[pos - 1]);
            }
        }

        maxDiffSegPts = max(maxDiffSegPts, pos);
    }

    /* No dimension can be split any further. */
    if (maxDiffSegPts <= 2) {
        recordLeaf();
        if (gNumLeafNode % 1000000 == 0)
            printf(".");
        return new hs_node(0, depth, 0, true, rules);
    }

    /* Copy each rule that intersects a child's sub-space, clipped to it in d2s. */
    auto clipInto = [&](const Rule &rule, int side, vector<Rule> &out) {
        if (rule.range[d2s][LowDim] <= range[side][HighDim] &&
            rule.range[d2s][HighDim] >= range[side][LowDim]) {
            Rule clipped = rule;
            clipped.range[d2s][LowDim] = max(clipped.range[d2s][LowDim], range[side][LowDim]);
            clipped.range[d2s][HighDim] = min(clipped.range[d2s][HighDim], range[side][HighDim]);
            out.push_back(clipped);
        }
    };

    vector<Rule> leftRule, rightRule;
    for (const auto &rule : rules) {
        clipInto(rule, 0, leftRule);
        clipInto(rule, 1, rightRule);
    }

    auto *node = new hs_node(d2s, depth, thresh, false, rules);

    /* The split does not reduce the rule set on one side, so recursing would
     * never terminate. Keep all rules here and account for the node as the
     * leaf it really is. */
    if (leftRule.size() == rules.size() || rightRule.size() == rules.size()) {
        node->isleaf = true;
        recordLeaf();
        return node;
    }

    gNumTreeNode++;
    node->child[0] = ConstructHSTree(leftRule, depth + 1);
    node->child[1] = ConstructHSTree(rightRule, depth + 1);

    return node;
}

/**
 * Estimates the memory footprint of the classifier from the build counters.
 *
 * @return Tree nodes times TREE_NODE_SIZE, plus leaf nodes times
 *         LEAF_NODE_SIZE, plus numRules times PTR_SIZE.
 */
Memory HyperSplit::MemSizeBytes() const {
    const auto treeNodeBytes = gNumTreeNode * TREE_NODE_SIZE;
    const auto leafNodeBytes = gNumLeafNode * LEAF_NODE_SIZE;
    const auto rulePointerBytes = numRules * PTR_SIZE;
    return treeNodeBytes + leafNodeBytes + rulePointerBytes;
}



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值