NetPlier: A Complete Source Code Walkthrough

NetPlier source code analysis, completed by Jiachi Liu.

Key passage from the paper (translated)

Our idea.

We apply a multiple sequence alignment (MSA) algorithm to the messages from both the client and the server side, and partition the messages into a list of fields. MSA tends to be conservative and produces only a single complete field list, which gives a solid starting point. For each field we introduce a random variable representing the probability that it is the keyword. Assuming a field is the keyword, the messages can be grouped into clusters by that field's values, and these clusters should satisfy a number of constraints, such as the message similarity constraint, the remote coupling constraint, the structure coherence constraint, and the dimension constraint. For each constraint we compute a probability reflecting the degree of compliance we observe. With these probabilities we then perform probabilistic inference to derive the posterior probability of the random variable representing our hypothesis, namely that the current field is the keyword. After all fields have been examined, we can pick the one with the largest probability as the keyword and use it to cluster the messages.
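As a toy illustration of the final selection step (my simplification: NetPlier's actual inference runs over a probabilistic graphical model, and the numbers below are invented):

    # For each candidate field, suppose we observed its degree of compliance
    # with the constraints (similarity, coupling, structure, dimension).
    constraint_probs = {
        1: [0.55, 0.50, 0.90, 0.80],   # hypothetical numbers
        2: [0.95, 0.90, 0.95, 0.85],
        3: [0.30, 0.40, 0.60, 0.50],
    }

    def score(probs):
        # naive posterior proxy: product of the observation probabilities
        result = 1.0
        for p in probs:
            result *= p
        return result

    keyword_fid = max(constraint_probs, key=lambda fid: score(constraint_probs[fid]))
    assert keyword_fid == 2   # the highest-scoring field is chosen as the keyword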

main.py

This file is NetPlier's program entry point; it gives the user a simple command-line interface.

Imports

import argparse
import sys
import os
import logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
#logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

from netplier import NetPlier
from processing import Processing
from alignment import Alignment
from clustering import Clustering

argparse: lets the program take parameters straight from the command line and run with them.

Reference: argparse模块用法实例详解 - 知乎 (zhihu.com)

logging: configures log output; messages at or above the configured level are printed to the console (a combined mini-example follows the code below).

Reference: https://zhuanlan.zhihu.com/p/166671955

Handling user input

if __name__ == '__main__':
    
    parser = argparse.ArgumentParser()

    parser.add_argument('-i', '--input', required=True, dest='filepath_input', help='filepath of input trace')
    parser.add_argument('-t', '--type', dest='protocol_type', help='type of the protocol (for generating the ground truth): \
        dhcp, dnp3, icmp, modbus, ntp, smb, smb2, tftp, zeroaccess')
    parser.add_argument('-o', '--output_dir', dest='output_dir', default='tmp_netplier/', help='output directory')
    parser.add_argument('-l', '--layer', dest='layer', default=5, type=int, help='the layer of the protocol')
    parser.add_argument('-m', '--mafft', dest='mafft_mode', default='ginsi', help='the mode of mafft: [ginsi, linsi, einsi]')
    parser.add_argument('-mt', '--multithread', dest='multithread', default=False, action='store_true', help='run mafft with multi threads')

    args = parser.parse_args()
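A minimal self-contained sketch of the two modules working together (a hypothetical script, not NetPlier code; the fixed argv is for demonstration only):

    import argparse, logging, sys

    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True, dest='filepath_input')
    args = parser.parse_args(['-i', 'trace.pcap'])  # parse a fixed argv for demo

    logging.info("input file: %s", args.filepath_input)  # printed (INFO >= INFO)
    logging.debug("this is filtered out")                # suppressed (DEBUG < INFO)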

Question: how do we analyze messages whose format is unknown?

Calling processing.py

    p = Processing(filepath=args.filepath_input, protocol_type=args.protocol_type, layer=args.layer)  
    # p.print_dataset_info()   
    mode = args.mafft_mode
    if args.protocol_type in ['dnp3']: # tftp
        mode = 'linsi'

Instantiates Processing from processing.py with the input file path, protocol type, and protocol layer; for dnp3 the MAFFT alignment mode is switched to linsi.

Calling netplier.py

    netplier = NetPlier(messages=p.messages, direction_list=p.direction_list, output_dir=args.output_dir, mode=mode, multithread=args.multithread)
    fid_inferred = netplier.execute()

Constructs a NetPlier object and runs its execute() function to obtain the inferred keyword field.

Clustering (1)

    messages_aligned = Alignment.get_messages_aligned(netplier.messages, os.path.join(netplier.output_dir, Alignment.FILENAME_OUTPUT_ONELINE))
    messages_request, messages_response = Processing.divide_msgs_by_directionlist(netplier.messages, netplier.direction_list)
    messages_request_aligned, messages_response_aligned = Processing.divide_msgs_by_directionlist(messages_aligned, netplier.direction_list)

All messages are aligned; the messages are split into client-side (request) and server-side (response) sets; the aligned messages are split the same way.

Clustering (2)

    clustering = Clustering(fields=netplier.fields, protocol_type=args.protocol_type)
    clustering_result_request_true = clustering.cluster_by_kw_true(messages_request)
    clustering_result_response_true = clustering.cluster_by_kw_true(messages_response)
    clustering_result_request_netplier = clustering.cluster_by_kw_inferred(fid_inferred, messages_request_aligned)
    clustering_result_response_netplier = clustering.cluster_by_kw_inferred(fid_inferred, messages_response_aligned)
    clustering.evaluation([clustering_result_request_true, clustering_result_response_true], [clustering_result_request_netplier, clustering_result_response_netplier])

Clusters the client-side and server-side messages and evaluates the clustering results.

First the messages are clustered by their true keywords (the ground truth); then the clustering induced by the inferred keyword candidate is compared against it for evaluation.


processing.py

Preprocesses the imported messages and extracts the useful information.

Imports

import logging
import copy
import re      # used by get_true_keyword's ftp branch (missing from the original listing)
import struct
from netzob.Import.PCAPImporter.all import *
from netzob.Model.Vocabulary.Session import Session

copy.copy: shallow copy; unlike copy.deepcopy, the copied object is not fully independent of the original.

Reference: python 模块 copy 复制详解 - ghostwritten的博客 - CSDN博客

struct: functions for converting between byte strings and native Python values.

Reference: https://blog.csdn.net/qdPython/article/details/115550281

The last two imports pull in Netzob functionality.

PCAPImporter is Netzob's facility for reading pcap files.
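Before moving on, a quick sketch of the struct calls that appear later in decrypt_za_msg (packing/unpacking a little-endian unsigned 32-bit integer):

    import struct

    raw = b'\x32\x70\x74\x66'
    (value,) = struct.unpack("<I", raw)     # little-endian u32 -> 0x66747032
    assert value == 0x66747032
    assert struct.pack("<I", value) == raw  # round-trip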

Processing class initialization

class Processing:
    MAX_LEN = 500 #100 // reduce the time for MSA

    def __init__(self, filepath, protocol_type=None, layer=5, messages=None):
        self.filepath = filepath
        self.protocol_type = protocol_type
        self.layer = layer
        self.messages = messages
        self.direction_list = list()

        if self.protocol_type:
            assert self.protocol_type in ['dhcp', 'dnp3', 'icmp', 'modbus', 'ntp', 'smb', 'smb2', 'tftp', 'zeroaccess'], 'the protocol_type is unknown'
        self.import_messages()
        self.get_msgs_directionlist()

Validates the protocol type (an assertion error is raised if it is not in the supported list), then imports the messages to build the message list.

Each element of direction_list is 0 or 1, marking the message as coming from the request side (0) or the response side (1).

Importing messages: import_messages

    def import_messages(self):
        print("[++++++++] Import messages")
        # ICMP: layer = 3
        if self.protocol_type == 'icmp':
            self.layer = 3
        messages = PCAPImporter.readFile(filePath=self.filepath, importLayer=self.layer).values()

For icmp the import layer is set to 3 (the network layer). Netzob's readFile function (readFile -> readFiles -> readMessages) then reads all messages from the pcap and returns them as a message list.

Each message object carries everything about one packet: source address, port, date, payload data, and so on. The data is stored as raw bytes and can be rendered as hex with binascii.hexlify.
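For example (toy bytes, not from a real capture):

    import binascii

    data = b'\xff\x53\x4d\x42'
    print(binascii.hexlify(data))  # b'ff534d42'
    print(data.hex())              # 'ff534d42' (equivalent, no import needed)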

        ## Filter messages
        # extract from IP msgs
        if self.protocol_type == "icmp":
            for message in messages:
                len_header = message.data[0] & 0x0000000f
                startIndex = len_header * 4 #*32/8
                message.data = message.data[startIndex:]
        # in mb2, some msgs contain more than one mbtcp
        elif self.protocol_type == 'modbus':
            for i, message in enumerate(messages):
                length = int.from_bytes(message.data[4:4+2], byteorder='big', signed=True)
                if len(message.data) != length + 6:
                    message.data = message.data[:length+6]
        elif self.protocol_type == 'smb':
            for message in messages[::-1]:
                # delete not smb msgs
                if message.data[4:8].hex() != "ff534d42":#smb协议的标志性关键字段
                    messages.remove(message)
                if len(message.data) > 500:
                    message.data = message.data[:500]
        elif self.protocol_type == 'smb2':
            for message in messages[::-1]:
                if message.data[4:8].hex() != "fe534d42":
                    messages.remove(message)
                if len(message.data) > 500:
                    message.data = message.data[:500]
        elif self.protocol_type == 'zeroaccess':
            for message in messages:
                message.data = self.decrypt_za_msg(message.data)

        # MAX_LEN = 500 
        for message in messages:
            if len(message.data) > Processing.MAX_LEN:
                message.data = message.data[:Processing.MAX_LEN]

        self.messages = messages

Message filtering: per-protocol cleanup of the raw payloads. For icmp the IP header is stripped from each message; for modbus a message containing more than one MBTCP PDU is truncated to the first one; for smb/smb2 messages without the SMB magic are removed and overly long ones are truncated to 500 bytes. Finally, every message is cut to its first MAX_LEN (500) bytes, yielding the message list after this first processing step.

The zeroaccess branch instead decrypts the payload (see decrypt_za_msg below).

Question: why does the SMB magic "ff534d42" sit at bytes 4-7? What do the leading bytes of a real Wireshark capture mean?
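(My understanding, not stated in the source: in SMB-over-TCP captures the first four bytes are the NetBIOS Session Service / direct-TCP transport header that carries the message length, which is why the 0xff534d42 magic begins at offset 4.) As for the icmp branch, a small sketch of the IP-header stripping arithmetic:

    # The low nibble of the first IP header byte is the IHL, in 32-bit words.
    first_byte = 0x45             # IPv4, IHL = 5
    len_header = first_byte & 0x0f
    start_index = len_header * 4  # 5 * 4 = 20 bytes, the usual IPv4 header size
    assert start_index == 20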

Decrypting the special zeroaccess messages: decrypt_za_msg

    def decrypt_za_msg(self, messagedata_encrypted):
        crc32 = struct.unpack("<I", messagedata_encrypted[0:4])[0]
        if crc32 == 0:
            return messagedata_encrypted

        key = 0x66747032
        result = []
        for i in range(0, len(messagedata_encrypted), 4):
            if (i + 4) >= len(messagedata_encrypted):
                break
            sub_data = struct.unpack("<I", messagedata_encrypted[i:i+4])[0]
            xored_subdata = sub_data ^ key # decryption: m = e ^ k
            decrpted_data = struct.pack("<I", xored_subdata)
            result.append(decrpted_data.hex())
            key = ((key << 1) & 0xffffffff | key >> 31) #密钥轮询
        messagedata_decrypted = ''.join(result)
        
        return bytes.fromhex(messagedata_decrypted)

Decrypts a zeroaccess packet. The first four bytes hold a CRC32; if that word is 0 the message is treated as unencrypted and returned as-is. The initial key is 0x66747032; the payload is processed four bytes at a time, XORing each 32-bit word with the key to recover the plaintext (m = e ^ k) and rotating the key left by one bit after every word, producing the final plaintext sequence.
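Since XOR with a rolling key is its own inverse, applying the transform twice restores the input. A simplified standalone version of the loop, for a round-trip check (note it XORs every complete 4-byte word, whereas the original's `(i + 4) >= len` test stops one word earlier):

    import struct

    def za_xor(data: bytes) -> bytes:
        key, out = 0x66747032, b''
        for i in range(0, len(data) - len(data) % 4, 4):
            (word,) = struct.unpack("<I", data[i:i+4])
            out += struct.pack("<I", word ^ key)
            key = ((key << 1) & 0xffffffff) | (key >> 31)  # rotate left 1 bit
        return out

    plain = b'\x01\x02\x03\x04\x05\x06\x07\x08'
    assert za_xor(za_xor(plain)) == plain  # decrypt(encrypt(m)) == m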

Generating the message direction list: get_msgs_directionlist

    ## generate direction list
    def get_msgs_directionlist(self):
        assert self.messages is not None, 'the messages could not be None'
        if not self.protocol_type or self.protocol_type == "tftp":
            direction_list = self.get_msgs_directionlist_by_sessions()
        else: ## get the direction by specification
            direction_list = list()
            for message in self.messages:
                d = self.get_msg_direction_by_specification(message)
                if d != 0 and d != 1:
                    logging.error("Error: GetMsgsDirectionlistBySpecification")
                direction_list.append(d)
        self.direction_list = direction_list

First asserts that the message list is not None.

If no protocol type was given, or the type is tftp, the direction list is built from sessions via get_msgs_directionlist_by_sessions(); otherwise each message's direction is derived from the protocol specification via get_msg_direction_by_specification(message).

Each element of direction_list is 0 or 1, marking the message as coming from the request side (0) or the response side (1).

Question: wasn't the protocol type already checked by the assertion during class initialization? (Note the assertion in __init__ only runs when protocol_type is truthy, so a None type still reaches this branch.)

Building the direction list from sessions: get_msgs_directionlist_by_sessions

    def get_msgs_directionlist_by_sessions(self):
        dict_idtoi = dict()
        for i,message in enumerate(self.messages):
            dict_idtoi[message.id] = i
        direction_list = [-1]*len(self.messages)
        sessions = Session(self.messages) # Netzob's Session class
        #print(sessions.getEndpointsList())
        #logging.info("Number of Sessions: {0}".format(len(sessions.getTrueSessions())))
        for session in sessions.getTrueSessions():  # iterate over every true session
            messages_list = list(session.messages.values())
            messages_list = sorted(messages_list, key=lambda x:x.date)
            srcIP = messages_list[0].source # source address of the session's first message
            for message in messages_list:
                if message.source == srcIP:
                    result = 0
                else:
                    result = 1
                direction_list[dict_idtoi[message.id]] = result
        return direction_list

sorted: orders the messages by x.date via a key function.

Reference: https://blog.csdn.net/Jeffxu_lib/article/details/88650431

This function classifies message direction with the help of Netzob sessions: within each true session, messages whose source matches the first message's source address are set to 0, and all others to 1.

sessions.getTrueSessions(): a Netzob method; its behavior is shown below.

    >>> from netzob.all import *
    >>> msg1 = RawMessage("SYN", source="A", destination="B")
    >>> msg2 = RawMessage("SYN/ACK", source="B", destination="A")
    >>> msg3 = RawMessage("ACK", source="A", destination="C")
    >>> session = Session([msg1, msg2, msg3])
    >>> print(len(session.getTrueSessions()))
    2
    >>> for trueSession in session.getTrueSessions():
    ...    print(trueSession.name)
    Session: 'A' - 'B'
    Session: 'A' - 'C'

It yields the sessions with their endpoint information. Session itself is also a Netzob class; it groups messages and orders them automatically, as the following doctest (also from Netzob) shows.

>>> import time
>>> from netzob.all import *
>>> # we create 3 messages
>>> msg1 = RawMessage("ACK", source="A", destination="B", date=time.mktime(time.strptime("9 Aug 13 10:45:05", "%d %b %y %H:%M:%S")))
>>> msg2 = RawMessage("SYN", source="A", destination="B", date=time.mktime(time.strptime("9 Aug 13 10:45:01", "%d %b %y %H:%M:%S")))
>>> msg3 = RawMessage("SYN/ACK", source="B", destination="A", date=time.mktime(time.strptime("9 Aug 13 10:45:03", "%d %b %y %H:%M:%S")))
>>> session = Session([msg1, msg2, msg3])
>>> print(session.messages.values()[0].data)
SYN
>>> print(session.messages.values()[1].data)
SYN/ACK
>>> print(session.messages.values()[2].data)
ACK

Direction from the specification: get_msg_direction_by_specification

    def get_msg_direction_by_specification(self, message):
        ##0: request; 1: response
        result = -1

        if self.protocol_type == "dhcp":
            #DHCP: original msgs data[0]; aligned msg (hex) data[1]
            if message.data[0] == 1:
                result = 0
            elif message.data[0] == 2:
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(message.data[0]))

Decided by the DHCP op field: a first byte of 1 (BOOTREQUEST) marks a request, 2 (BOOTREPLY) a response.

        elif self.protocol_type == "dnp3":
            f = (message.data[3] >> 7) & 0x01
            #1: from master
            #0: from outstation
            if f == 1:
                result = 0
            elif f == 0:
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(f))

In DNP3 the DIR bit (bit 7 of the fourth byte, the data-link control octet) is checked: 1 means the frame comes from the master (request side), 0 from the outstation (response side). With the typical control octets, 0xC4 (DIR = 1) therefore marks a request and 0x44 (DIR = 0) a response.

        elif self.protocol_type == "ftp":
            port_source, port_destination = message.source.split(":")[1], message.destination.split(":")[1]
            server_port = ["20", "21"]
            if port_source in server_port:
                result = 1
            elif port_destination in server_port:
                result = 0
            else:
                logging.error("Can not decide the direction of msg port: {} {}".format(message.source, message.destination))
        elif self.protocol_type == "icmp":
            ##9/10: not sure
            if message.data[0] in [8, 13, 15, 17, 10]:
                result = 0
            elif message.data[0] in [0, 3, 4, 5, 11, 12, 14, 16, 18, 9]:
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(message.data[0]))
        elif self.protocol_type == "modbus":
            port_source, port_destination = message.source.split(":")[1], message.destination.split(":")[1]
            if port_source == "502":
                result = 1
            elif port_destination == "502":
                result = 0
            else:
                logging.error("Can not decide the direction of msg port: {} {}".format(message.source, message.destination))
        elif self.protocol_type == "ntp":
            f = message.data[0] & 0x07
            #1: symmetric active; 2: Symmetric Passive
            #3: client; 4: server;
            #5: broadcast server; 6: Broadcast Client
            if f == 1 or f == 3 or f == 5:
                result = 0
            elif f == 2 or f == 4 or f == 6:
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(f))
        elif self.protocol_type == "smb":
            smb_flag = message.data[4+9]
            direction = smb_flag & 0x80
            if direction == 0:
                result = 0
            elif direction == 128:
                result = 1
            else:
                print("Can not decide the direction of msg: {}".format(direction))
        elif self.protocol_type == "smb2":
            #print(message.data[4+16:4+16+4].hex())
            smb_flag = struct.unpack("<I", message.data[4+16:4+16+4])[0]
            direction = smb_flag & 0x1
            #print(direction, type(direction))
            if direction == 0:
                result = 0
            elif direction == 1:
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(direction))
        elif self.protocol_type == "zeroaccess":
            #g: 103; r: 114; n: 110
            if message.data[7] == 103:
                result = 0
            elif (message.data[7] == 114) or (message.data[7] == 110):
                result = 1
            else:
                logging.error("Can not decide the direction of msg: {}".format(message.data[7]))
        else:
            logging.error("The protocol_type is not unknown to detect direction")

        if result == -1:
            logging.error("Error: can't decide the drection: {0}".format(message.data))

        return result

Each branch above uses a protocol-specific feature from the specification to decide the direction of every message.

Printing dataset info: print_dataset_info

    def print_dataset_info(self):
        assert self.protocol_type is not None, 'need the protocol_type to get dataset info'
        print("\n[++++++++] Get Dataset Info")

        ## Number of msgs
        messages_request, messages_response = Processing.divide_msgs_by_directionlist(self.messages, self.direction_list)
        print("Total msg number: {0}\nRequest msg number: {1}\nResponse msg number: {2}\n".format(len(self.messages), len(messages_request), len(messages_response)))

Splits the message list into the request and response directions via divide_msgs_by_directionlist and prints the message count of each set.

        types_list_request = [self.get_true_keyword(message) for message in messages_request]
        types_list_response = [self.get_true_keyword(message) for message in messages_response]
        print("Request Symbols: {}".format(set(types_list_request)))
        print("Response Symbols: {}".format(set(types_list_response)))

        print("Number of request symbols: {0}".format(len(set(types_list_request))))
        for s in set(types_list_request):
            print("  Symbol {0} msgs numbers: {1}".format(s, types_list_request.count(s)))
        print("Number of response symbols: {0}".format(len(set(types_list_response))))
        for s in set(types_list_response):
            print("  Symbol {0} msgs numbers: {1}".format(s, types_list_response.count(s)))

Uses get_true_keyword to obtain each message's true type symbol and counts, per direction, how many messages each symbol has.

        messages = copy.deepcopy(self.messages)
        sessions = Session(messages)
        for i in range(len(self.direction_list)):
            data = [messages[i].data, self.direction_list[i]]
            messages[i].data = data
        num_of_session = len(sessions.getTrueSessions())
        print("\nNumber of Sessions: {0}".format(num_of_session))
        print("[++++++++] End\n")

Prints session information (the number of true sessions).

Splitting messages by direction: divide_msgs_by_directionlist

    @staticmethod
    def divide_msgs_by_directionlist(messages, direction_list):
        messages_request = list()
        messages_response = list()
        for i in range(len(direction_list)):
            if direction_list[i] == 0:
                messages_request.append(messages[i])
            else:
                messages_response.append(messages[i])

        return messages_request,messages_response

A static method, providing an interface that external code can call without a Processing instance.

Getting the true keyword defined by the specification: get_true_keyword

   # get the true keyword defined by the specification
    def get_true_keyword(self, message):
        if self.protocol_type == "dhcp":
            kw = message.data[242:243]
        elif self.protocol_type == "dnp3":
            kw = message.data[12:13]
        elif self.protocol_type == "ftp":
            kw = re.split(" |-|\r|\n", message.data.decode())[0]
        elif self.protocol_type ==  "icmp":
            kw = message.data[0:2]
        elif self.protocol_type == "modbus":
            kw = message.data[7:8]
        elif self.protocol_type == "ntp":
            kw = message.data[0] & 0x07
        elif self.protocol_type == "smb":
            kw = message.data[4+4]
        elif self.protocol_type == "smb2":
            kw = struct.unpack("<H", message.data[4+12:4+12+2])[0]
        elif self.protocol_type == "tftp":
            kw = message.data[0:2]
        elif self.protocol_type == "zeroaccess":
            kw = message.data[4:8]
        else:
            logging.error("The TestName is not given known method for detecting direction.")

        if type(kw).__name__ == "bytes":
            kw = str(kw.hex())
        return kw

netplier.py

Implements the alignment, constraint computation, and probabilistic inference steps.

Imports

import logging
import os

from netzob.Model.Vocabulary.Field import Field
from netzob.Model.Vocabulary.Types.Raw import Raw
#from netzob.all import *
#from netzob.Model.Vocabulary.Session import Session
#from netzob.Model.Vocabulary.Field import Field

from alignment import Alignment
from constraint.constraint import Constraint
from probabilistic_inference import ProbabilisticInference

Field: a Netzob class; usage below.

A symbol's structure follows a format that specifies an expected sequence of fields: a TCP segment, for example, contains expected fields for the sequence number and the checksum. Fields have a fixed or a variable size, and may themselves be composed of sub-elements (a payload field, say). For an efficient representation, fields are organized as a tree with the field's symbol as the root and sub-fields as children; a field therefore always has a parent (another field or the tree root) and may optionally have child fields. A field's value is defined by its domain, which can be a simple static value such as an ASCII string or a decimal integer, or something more complex such as a transformation, an encoding filter, or a relationship.

A few simple field examples:

    f = Field(100): a field holding the decimal value 100
    f = Field(0b1000): a field holding a specific binary value, 0b1000 (decimal 8)
    f = Field(Raw(nbBytes=(8, 9))): a field of 8 to 9 raw bytes
    f = Field(Raw('\x00\x01\x02\x03')): a field with a specific raw value
    f = Field(IPv4()): a field representing a random IPv4 address
    f = Field([Size(payloadField)]): a field whose value is the size of payloadField

And a few fields with alternative values:

    f = Field([10, ASCII(nbChars=(10, 11))]): a field holding either the decimal value 10 or a 10-11 character ASCII string
    f = Field(["netzob", "zoby"]): a field with two alternative ASCII values, 'netzob' or 'zoby'

Raw: a Netzob data type; usage below.

The raw Netzob data type, expressed as bytes.

For example, this type can parse any 2-byte raw field:

    f = Field(Raw(nbBytes=2))

Or one with a specific value (little-endian by default):

    f = Field(Raw('\x01\x02\x03'))

The optional alphabet parameter restricts which bytes may appear in the domain value:

    f = Field(Raw(nbBytes=100, alphabet=["t", "o"]))

NetPlier class initialization

class NetPlier:
    def __init__(self, messages, direction_list=None, output_dir='tmp/', mode='ginsi', multithread=False):
        self.messages = messages
        self.direction_list = direction_list
        self.output_dir = output_dir
        self.mode = mode
        self.multithread = multithread

        if not os.path.exists(self.output_dir):
            logging.debug("Folder {0} doesn't exist".format(self.output_dir))
            os.makedirs(self.output_dir)

Stores the output directory and the initial MAFFT mode; the output directory is created if it does not exist.

Generating fields from the MAFFT results: generate_fields_by_fieldsinfo

    # Generate fields from mafft results
    def generate_fields_by_fieldsinfo(self, filepath_fields_info):
        print("[++++++++] Generate fields")
        assert os.path.isfile(filepath_fields_info), "The fields info file doesn't exist"
        fid_list = list() # indices of the dynamic fixed-length fields (the keyword candidates)
        fields_result = list()        
        with open(filepath_fields_info) as f:
            line_list = f.readlines()
            for i, line in enumerate(line_list):
                typename, typesizemin, typesizemax, fieldtype = line.split()
                typeinfo = [typename, int(typesizemin), int(typesizemax)]
                fields_result.append(typeinfo)
                if fieldtype == 'D':
                    fid_list.append(i)
        fields = self.generate_fields(fields_result) # each field's allowed byte-length range: Field(Raw(nbBytes=(typeinfo[1]//8, typeinfo[2]//8)))
        logging.debug("Number of fields: {0}".format(len(fields)))
        return fields, fid_list

For a line such as "Raw 0 32", the stored field satisfies field.domain.dataType.typeName == 'Raw', field.domain.dataType.size[0] == 0, and field.domain.dataType.size[1] == 32.
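A sketch of how one msa_fields_info.txt line is consumed (toy input mirroring the loop above):

    line = "Raw 0 32 D"
    typename, typesizemin, typesizemax, fieldtype = line.split()
    typeinfo = [typename, int(typesizemin), int(typesizemax)]  # ['Raw', 0, 32]
    # 32 bits -> at most 32 // 8 = 4 bytes, so this becomes Field(Raw(nbBytes=(0, 4)))
    assert (typeinfo[1] // 8, typeinfo[2] // 8) == (0, 4)
    assert fieldtype == 'D'  # dynamic fixed-length -> a keyword candidate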

Building Field objects from the per-field length info: generate_fields

    ## Generate fields
    def generate_fields(self, fields_result):
        fields = list()
        for typeinfo in fields_result:
            if typeinfo[0] == "Raw":
                field = Field(Raw(nbBytes=(typeinfo[1]//8, typeinfo[2]//8)))
                fields.append(field)
            else:
                logging.error("Field type is not Raw")

        return fields

NetPlier's execution flow: execute

    def execute(self):
        
        # Alignment
        # TODO: choose mode automatically
        msa = Alignment(messages=self.messages, output_dir=self.output_dir, mode=self.mode, multithread=self.multithread)
        #msa = Alignment(messages=self.messages, output_dir=self.output_dir, multithread=True)
        msa.execute()
        # exit()
        
        # Generate fields
        filepath_fields_info = os.path.join(self.output_dir, Alignment.FILENAME_FIELDS_INFO)
        self.fields, fid_list = self.generate_fields_by_fieldsinfo(filepath_fields_info)
        logging.debug("Number of keyword candidates: {}\nfid: {}".format(len(fid_list), fid_list))

First the messages are aligned with multiple sequence alignment, producing a series of output files; generate_fields_by_fieldsinfo then processes the extracted field-info file and yields the indices of the keyword candidate fields.

        # Compute probabilities of observation constraints
        constraint = Constraint(messages=self.messages, direction_list=self.direction_list, fields=self.fields, fid_list=fid_list, output_dir=self.output_dir)
        
        pairs_p, pairs_size = constraint.compute_observation_probabilities()
        pairs_p_request, pairs_p_response = pairs_p
        pairs_size_request, pairs_size_response = pairs_size
        constraint.save_observation_probabilities(pairs_p_request, pairs_size_request, Constraint.TEST_TYPE_REQUEST)
        constraint.save_observation_probabilities(pairs_p_response, pairs_size_response, Constraint.TEST_TYPE_RESPONSE)

Computes the observation constraints for the request and response sides separately and saves them: the request side (marked with test type 0) to prob_request.txt and the response side (marked with 1) to prob_response.txt.

        # Probabilistic inference
        pairs_p_all, pairs_size_all = self.merge_constraint_results(pairs_p_request, pairs_p_response, pairs_size_request, pairs_size_response)

        ffid_list = ["{0}-{0}".format(fid) for fid in fid_list] #only test same fid for both sides
        pi = ProbabilisticInference(pairs_p=pairs_p_request, pairs_size=pairs_size_request)
        fid_inferred = pi.execute(ffid_list)
        
        ## TODO: iterative
        ## TODO: format inference
        
        return fid_inferred

Probabilistic inference then produces the recommended keyword field (only the same field id on both sides is tested). For dnp3, for example, ffid_list is ['1-1', '2-2', '3-3', '5-5', '7-7', '8-8', '9-9', '10-10', '11-11', '12-12', '13-13'].


alignment.py

This file aligns the preprocessed message list and produces the list of keyword candidate fields.

Imports

import subprocess
import os
import logging
import copy

subprocess: for spawning child processes and interacting with them through the shell.

Reference: https://blog.csdn.net/qq_37674086/article/details/84983843

Alignment class initialization

class Alignment:
    FILENAME_INPUT = "msa_input.fa"
    FILENAME_OUTPUT = "msa_output.txt"
    FILENAME_OUTPUT_ONELINE = "msa_output_oneline.txt"
    FILENAME_FIELDS_INFO = "msa_fields_info.txt"
    FILENAME_FIELDS_VISUAL = "msa_fields_visual.txt"

    def __init__(self, messages, output_dir='tmp/', mode='ginsi', multithread=False, ep=0.123):
        self.messages = messages
        self.output_dir = output_dir
        self.mode = mode
        self.multithread = multithread
        self.ep = ep
        '''
        self.nthread = nthread
        self.nthreadtb = nthreadtb
        self.nthreadit = nthreadit
        '''

        self.filepath_input = os.path.join(self.output_dir, Alignment.FILENAME_INPUT)
        self.filepath_output = os.path.join(self.output_dir, Alignment.FILENAME_OUTPUT)
        self.filepath_output_oneline = os.path.join(self.output_dir, Alignment.FILENAME_OUTPUT_ONELINE)
        self.filepath_fields_info = os.path.join(self.output_dir, Alignment.FILENAME_FIELDS_INFO)
        self.filepath_fields_visual = os.path.join(self.output_dir, Alignment.FILENAME_FIELDS_VISUAL)

Defines the intermediate files:

msa_input.fa: the MAFFT input, hex-encoded, with "~" separating the bytes

msa_output.txt: the MAFFT alignment result; missing positions within a field are padded with "-"

msa_output_oneline.txt: each aligned message rearranged onto a single line

msa_fields_info.txt: the per-field information derived from the analysis

msa_fields_visual.txt: a human-readable (visual) view of the field split

Execution flow: execute

    def execute(self):
        ## Generate msa input (with tilde)
        self.create_mafft_input_with_tilde()
        ## Execute Mafft
        self.execute_mafft()
        ## Change to oneline
        self.change_to_oneline()
        ## Remove tilde
        self.remove_character(self.filepath_output_oneline)
        ## Analyze fields
        self.generate_fields_info(self.filepath_output_oneline)
        self.generate_fields_visual_from_fieldsinfo()

生成MSA输入->执行mafft对齐策略->按行整理排序结果->删除"~"符号->提取字段信息->生成可视化信息

Building the MAFFT input file, with "~" inserted between bytes: create_mafft_input_with_tilde

    # hex, add "~" after each byte
    def create_mafft_input_with_tilde(self):
        message_data_hex = list()
        for message in self.messages:
            message_data_hex.append(message.data.hex())

        with open(self.filepath_input, 'w') as f:
            for i, message in enumerate(message_data_hex):
                message_space = '~'.join(message[j:j+2] for j in range(0, len(message), 2))
                f.write(">{0}\n{1}\n".format(i, message_space))

Running the MAFFT alignment: execute_mafft

    def execute_mafft(self):
        print("[++++++++] Execute Alignment")

        assert self.mode in ["ginsi", "linsi", "einsi"], "the mafft mode should be ginsi, linsi, or einsi"

        if not self.multithread:
            cmd = f"mafft-{self.mode} --inputorder --text --ep {self.ep} --quiet {self.filepath_input} > {self.filepath_output}"
        else:
            cmd = f"mafft-{self.mode} --thread -1 --inputorder --text --ep {self.ep} --quiet {self.filepath_input} > {self.filepath_output}"
            #cmd = f"mafft-{self.mode} --thread {self.nthread} --threadtb {self.nthreadtb} --threadit {self.nthreadit} --inputorder --text --ep {self.ep} {self.filepath_input} > {self.filepath_output}"
        logging.debug("mafft cmd: {}".format(cmd))
        
        #run mafft
        result = subprocess.check_output(cmd, shell=True) # run mafft

MAFFT offers three alignment modes here: ginsi, linsi, and einsi; the mode is validated first.

MAFFT reference: https://mafft.cbrc.jp/alignment/software/algorithms/algorithms.html

Then the threading option is chosen, and the alignment output is written to msa_output.txt.
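With the defaults (mode='ginsi', ep=0.123, single thread, output_dir='tmp_netplier/'), the constructed command string comes out as shown in this small sketch:

    mode, ep = 'ginsi', 0.123
    filepath_input = 'tmp_netplier/msa_input.fa'
    filepath_output = 'tmp_netplier/msa_output.txt'
    cmd = f"mafft-{mode} --inputorder --text --ep {ep} --quiet {filepath_input} > {filepath_output}"
    # -> mafft-ginsi --inputorder --text --ep 0.123 --quiet tmp_netplier/msa_input.fa > tmp_netplier/msa_output.txt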

Rearranging the alignment into one line per message: change_to_oneline

    ## process alignment results files
    def change_to_oneline(self):
        logging.debug("[+] Change to oneline")

        assert os.path.isfile(self.filepath_output), "The msa output file doesn't exist"

        isfirstline = True
        with open(self.filepath_output) as f:
            with open(self.filepath_output_oneline, 'w') as fout:
                for line in f.read().splitlines():
                    if line.startswith('>'):
                        if isfirstline:
                            isfirstline = False
                        else:
                            fout.write("\n")
                    else:
                        fout.write("{0}".format(line))

First asserts that the MSA output file exists; then, scanning line by line, a line starting with '>' begins a new output line (except for the very first record), while all other lines are concatenated onto the current one.

Removing the "~" characters: remove_character

    def remove_character(self, filepath):
        logging.debug("[+] Remove character")

        assert os.path.isfile(filepath), "The file doesn't exist: {}".format(filepath)

        with open(filepath) as f:
            linelist = f.read().splitlines()

        results = [list() for i in range(len(linelist))]

        for i in range(len(linelist[0])):
            isToDelete = True
            for line in linelist:
                if line[i] != '-' and line[i] != '~':
                    isToDelete = False
                    break
            if not isToDelete:
                for j,line in enumerate(linelist):
                    #print("{0} {1}".format(i, j))
                    results[j].append(line[i])

        with open(filepath, 'w') as fout:
            for line in results:
                fout.write("{0}\n".format(''.join(line)))

Scans column by column from left to right and drops any column in which every message holds only '-' or '~'; all other columns are kept.
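A toy illustration of the column filter (a standalone re-implementation of the loop above):

    linelist = ['0a~-b', '0c~db']   # two aligned messages
    kept_columns = []
    for i in range(len(linelist[0])):
        # keep the column unless every row holds '-' or '~'
        if any(line[i] not in '-~' for line in linelist):
            kept_columns.append(i)
    results = [''.join(line[i] for i in kept_columns) for line in linelist]
    assert results == ['0a-b', '0cdb']   # column 2 ('~', '~') was dropped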

Checking that every item has an even digit count: has_even_number_of_bytes

    def has_even_number_of_bytes(self, valuelist):
        for value in valuelist:
            value_string = ''.join(value)
            if len(value_string.replace("-", "")) % 2 != 0:
                return False
        return True

Checks each item of valuelist individually (after stripping '-'); the check passes only if every item has an even number of hex digits. In theory every item has the same length anyway.

Checking for a variable-length field: is_variable_field

    def is_variable_field(self, valuelist):
        for value in valuelist:
            if '-' in value:
                return True
        return False

If any value in valuelist contains '-', the field is judged to be variable-length.

Generating the field info: generate_fields_info

    def generate_fields_info(self, filepath_input):
        logging.debug("[+] Generate fields info")
        
        assert os.path.isfile(filepath_input), "The file doesn't exist: {}".format(filepath_input)

        with open(filepath_input) as f:
            linelist = f.read().splitlines()

        length_message = len(linelist[0])

        ## Only record fields info
        results_fields = list()

        i = 0
        isLastStatic = False

Asserts that the input file exists, reads each of its lines into linelist, records the per-line digit count in length_message, and uses results_fields to collect the field information.

        while i < length_message:
            offset = 2
            while i + offset <= length_message:
                valuelist = [line[i:i+offset] for line in linelist]
                if not self.has_even_number_of_bytes(valuelist): # check every item has an even digit count
                    offset += 1  
                    continue
                else:
                    break

Suppose linelist is ['asdawdasdawsdawdasd', 'sdawdasdawfassadawd', 'asdffawfafhftjfgjfj']; then valuelist = [line[0:0+2] for line in linelist] is ['as', 'sd', 'as'], and has_even_number_of_bytes(valuelist) is True.

This inner loop widens the window until every item in valuelist has an even number of hex digits: in a real packet the smallest field unit is the byte, each byte is two hex digits, so an item with an odd digit count would be invalid.

            if not len(set(valuelist)) == 1: # set() drops duplicate values
                if self.is_variable_field(valuelist):
                    fields_info = [offset, 'V']          
                else:
                    fields_info = [offset, 'D']
                results_fields.append(fields_info)
                isLastStatic = False

set(valuelist): removes the duplicate items from the list.

Reference: https://www.cnblogs.com/feixiangtaiyang/p/14591572.html

This verifies the extracted column span: if the deduplicated set has more than one element, the field is not static. Two cases follow: if any value contains '-', it is a variable-length field, marked 'V'; otherwise it is a dynamic fixed-length field, marked 'D'. The field's width and type are then appended to results_fields.

            else:
                if isLastStatic:
                    results_fields[-1][0] += offset
                else:
                    fields_info = [offset, 'S']
                    results_fields.append(fields_info)
                isLastStatic = True

            i = i + offset
        logging.debug("Number of fields: {0}".format(len(results_fields)))

This branch extracts static fields: when the deduplicated set has exactly one element, the span is static. isLastStatic lets consecutive static spans be merged greedily while scanning right; the merged span is marked 'S' and, as before, its width and type are recorded in results_fields.

After a field is finalized, offset starts counting the next field, so the offsets ultimately sum to the length of each (aligned) message.

        with open(self.filepath_fields_info, 'w') as fout:
            for fields_info in results_fields:
                fout.write("Raw 0 {0} {1}\n".format(fields_info[0]*8, fields_info[1]))

Writes the field info to msa_fields_info.txt, converting each width into a bit count (width × 8).

Example:

Suppose the resulting results_fields is [[2, 'V'], [3, 'S'], [6, 'D'], [10, 'D'], [20, 'S']]; then the file records "Raw 0 16 V\n", "Raw 0 24 S\n", "Raw 0 48 D\n", "Raw 0 80 D\n", "Raw 0 160 S\n".

Reading the field info back: get_fields_info

    def get_fields_info(self):
        assert os.path.isfile(self.filepath_fields_info), "The fields info file doesn't exist"

        fields_info = dict() #pos:type
        pos = 0
        with open(self.filepath_fields_info) as f:
            line_list = f.readlines()
            for i, line in enumerate(line_list):
                typename, typesizemin, typesizemax, fieldtype = line.split()
                pos += int(typesizemax) // 8
                fields_info[pos] = fieldtype

        return fields_info

Reads back the field info produced above, in preparation for generating the visual field view.

One iteration of the for loop, as an example:

"Raw 0 24 S" splits into Raw (typename), 0 (typesizemin), 24 (typesizemax), and S (fieldtype). pos then accumulates, in bytes, the end position of each field within a message; continuing the example above, fields_info[5] = 'S'.

Generating the visual fields from fields_info: generate_fields_visual_from_fieldsinfo

    def generate_fields_visual_from_fieldsinfo(self):
        ## get fileds_info
        fields_info = self.get_fields_info()
        #print(fields_info)

        assert os.path.isfile(self.filepath_output_oneline), "The msa output oneline file doesn't exist"
        with open(self.filepath_output_oneline) as f:
            messages_data_mafft = f.read().splitlines()

messages_data_mafft holds the aligned messages, one per line.

        with open(self.filepath_fields_visual, 'w') as fout:
            for messages_data in messages_data_mafft:
                fields_value = list()
                pos_list = sorted(list(fields_info.keys()))
                pos_start = 0
                for i,pos_end in enumerate(pos_list):
                    fields_value.append(messages_data[pos_start:pos_end]) # one field's slice
                    pos_start = pos_end
                fields_value.append(messages_data[pos_start:])
                # at this point each item of fields_value holds one field
                fout.write("{0}\n".format(' '.join(fields_value)))
                # write the fields space-separated, to ease the later keyword inference

list(fields_info.keys()) gives the end positions recorded in the fields_info dict; pos_list holds them sorted, i.e. the left-to-right field boundaries applied to every message.

Static helper: get_messages_aligned

    @staticmethod
    def get_messages_aligned(messages, filepath_output_oneline):
        assert os.path.isfile(filepath_output_oneline), "The msa output oneline file doesn't exist"

        messages_aligned = copy.deepcopy(messages)
        with open(filepath_output_oneline) as f:
            messages_aligned_data = f.read().splitlines()

        for i in range(len(messages_aligned)):
            messages_aligned[i].data = messages_aligned_data[i]

        return messages_aligned

Called from outside the class: copies the messages and replaces each message's data with the corresponding aligned line from filepath_output_oneline.


message_similarity.py

This module computes the message similarity constraint among the observation (prior) probabilities.

MessageSimilarity class initialization

import logging

class MessageSimilarity:

    def __init__(self, messages):
        self.messages = messages
        self.similarity_matrix = list()

The input is the message list; a similarity_matrix list is also initialized.

Computing the similarity score s of two aligned messages: compute_similarity_scores_by_alignment

    def compute_similarity_scores_by_alignment(self, msgdata1, msgdata2):
        if len(msgdata1) != len(msgdata2):
            logging.error("The two compared messages don't have same length.")
            return -2
        # TODO: use NW to get more accurate score
        result = [1 for i in range(len(msgdata1)) if msgdata1[i]==msgdata2[i]]
        score = sum(result)/len(msgdata1)
        return score

Here s = (number of identical positions) / (message length); the two aligned messages have equal length.

The score computation seems questionable: each msgdata1[i] is only half a byte (one hex digit), so the result can differ considerably from a byte-level score.
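A small sketch of that concern (toy aligned hex strings; the byte-level variant is a hypothetical alternative, not in the source):

    a, b = '0a1b', '0a2b'   # two aligned 2-byte messages in hex

    # per-hex-digit score, as in compute_similarity_scores_by_alignment
    nibble_score = sum(1 for x, y in zip(a, b) if x == y) / len(a)   # 3/4 = 0.75

    # per-byte score: compare 2-digit slices instead
    pairs = [(a[i:i+2], b[i:i+2]) for i in range(0, len(a), 2)]
    byte_score = sum(1 for x, y in pairs if x == y) / len(pairs)     # 1/2 = 0.5

    assert (nibble_score, byte_score) == (0.75, 0.5)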

Building the similarity score matrix: compute_similarity_matrix

    def compute_similarity_matrix(self):
        print("[++++] Compute matrix of similarity scores")
        scoreslist = list()
        for i in range(len(self.messages)):
            initial_scores_list = [-1 for i in range(len(self.messages))]
            scoreslist.append(initial_scores_list)

Let scoreslist be the score matrix; it is initialized with -1 entries.

        # use the MSA result is quick, but less accurate
        for i in range(len(self.messages)):
            for j in range(i, len(self.messages)):
                if j == i:
                    score = 100.0
                    scoreslist[i][j] = score
                else:
                    score = self.compute_similarity_scores_by_alignment(self.messages[i].data, self.messages[j].data)
                    scoreslist[i][j] = score
                    scoreslist[j][i] = score 
        
        self.similarity_matrix = scoreslist

This step yields the similarity score matrix (when used, it is also dumped to a file for inspection).

Example: for a set of 3 messages the similarity matrix starts as all -1; after the computation it could look like:

    100  92  96
     92 100  88
     96  88 100

(Note the code stores 100.0 on the diagonal but fractional scores in [0, 1] elsewhere; the example uses percentage-style numbers for readability.)

Computing intra-cluster and inter-cluster similarity scores: compute_inner_inter_scores

    def compute_inner_inter_scores(self, symbols):
        logging.debug("[+] Compute Inner/Inter Scores")

        # dict_mid_i maps each message id to its index
        dict_mid_i = dict()
        for i,message in enumerate(self.messages):
            dict_mid_i[message.id] = i
        

        inner_inter_scores = dict()

        for s in symbols.values():
            sn = str(s.name)
            # sn is the symbol name, e.g. '14'
            #0: message num
            #1: inner scores list
            #2: inter scores list
            inner_inter_scores[sn] = list()
            # TODO: message num is not used
            
            mi_list = [dict_mid_i[message.id] for message in s.messages]
            # mi_list holds the message indices owned by this symbol; e.g. for sn '14', mi_list = [0, 3, 9, 12, 23, 26, 38, 41]
            inner_inter_scores[sn].append(mi_list) #0: message num
            # so after this line inner_inter_scores['14'] == [[0, 3, 9, 12, 23, 26, 38, 41]]
            
            inner_score_list, inter_score_list = list(), list()
            # from the similarity matrix, collect the pairwise intra-cluster scores
            for i in range(len(mi_list)):
                for j in range(i + 1, len(mi_list)):
                    inner_score_list.append(self.similarity_matrix[mi_list[i]][mi_list[j]])
            # and the pairwise inter-cluster scores (cluster members vs everything else)
            for i in mi_list:
                for j in range(len(self.messages)):
                    if j not in mi_list:
                        inter_score_list.append(self.similarity_matrix[i][j])
            # sort both lists in descending order and append them to inner_inter_scores[sn]
            inner_inter_scores[sn].append(sorted(inner_score_list, reverse=True))
            inner_inter_scores[sn].append(sorted(inter_score_list, reverse=True))
            
        return inner_inter_scores

The comments above all use symbols['14'] = [0, 3, 9, 12, 23, 26, 38, 41] as the running example. The function returns the inner_inter_scores dict: for each symbol name 'xx', a three-item list holding (0) the indices of the messages belonging to that symbol, (1) the intra-cluster scores in descending order, and (2) the inter-cluster scores in descending order; overall, a dict of lists of lists.

Computing the FNMR: compute_fnmrs

    def compute_fnmrs(self, scores):
        scores.sort() # sort the intra-cluster scores ascending
        numGM = len(scores)
        t_fnmr_list = list()

        # first one: [0, 0]
        result = [0, 0]
        t_fnmr_list.append(result) # at t = 0 the FNMR is necessarily 0
        
        t = -1
        for i in range(0, numGM): # compute the FNMR at each distinct threshold t
            if scores[i] > t :
                if (t != -1):
                    fnmr = i / numGM #i-1+1 / numGM
                    result = [t, fnmr]  # set t to each distinct scores[i] and record the FNMR there
                    t_fnmr_list.append(result)
                t = scores[i]
        result = [scores[i], 1]
        t_fnmr_list.append(result) # when t reaches the last score, the FNMR is 1

        # last one: [1, 1]
        result = [1, 1]
        t_fnmr_list.append(result) # at t = 1 the FNMR is 1

        return t_fnmr_list

The returned t_fnmr_list is a 2-D list: each entry's first element is a threshold t and the second the FNMR at that t. For the cluster with symbol '14', for instance, t_fnmr_list is [[0, 0], [0.9897435897435898, 0.5714285714285714], [1.0, 1], [1, 1]].

Computing the FMR: compute_fmrs

    def compute_fmrs(self, scores):
        scores.sort()
        numIM = len(scores)
        t_fmr_list = list()

        # first one: [0, 1]
        result = [0, 1]
        t_fmr_list.append(result)
        
        t = -1
        for i in range(0, numIM):
            if scores[i] > t :
                if (t != -1):
                    fmr = (numIM - i) / numIM # counts the scores above t
                    result = [t, fmr]
                    t_fmr_list.append(result)
                t = scores[i]
        result = [scores[i], 0]
        t_fmr_list.append(result)

        # last one: [1, 0]
        result = [1, 0]
        t_fmr_list.append(result)

        return t_fmr_list

Same principle and output format as compute_fnmrs, except the FMR counts the inter-cluster scores above the threshold.

Computing the EER: compute_eer

    def compute_eer(self, inner_scores, inter_scores):
        #tfnmr = stat_scores(inner_score_list)
        #tfmr = stat_scores(inter_score_list)
        if len(inner_scores) == 0 or len(inter_scores) == 0:
            return 1 # 0.05

        t_fnmr_list = self.compute_fnmrs(inner_scores)
        t_fmr_list = self.compute_fmrs(inter_scores)

        tfnmrlist = [x[0] for x in t_fnmr_list]
        fnmrlist = [x[1] for x in t_fnmr_list]
        tfmrlist = [x[0] for x in t_fmr_list]
        fmrlist = [x[1] for x in t_fmr_list]

        ifnmr = 0
        ifmr = 0
        fnmr1 = 0.0
        fnmr2 = 0.0
        fmr1 = 1.0
        fmr2 = 1.0
        tfnmr1 = 0.0
        tfnmr2 = 0.0
        tfmr1 = 0.0
        tfmr2 = 0.0
        while True:
            if fmr2 <= fnmr2:
                break
            if  tfmr2 < tfnmr2:
                ifmr += 1
                tfmr1 = tfmr2
                fmr1 = fmr2
                tfmr2 = tfmrlist[ifmr]
                fmr2 = fmrlist[ifmr]
            elif tfmr2 > tfnmr2:
                ifnmr += 1
                tfnmr1 = tfnmr2
                fnmr1 = fnmr2
                tfnmr2 = tfnmrlist[ifnmr]
                fnmr2 = fnmrlist[ifnmr]
            else:
                ifmr += 1
                tfmr1 = tfmr2
                fmr1 = fmr2
                tfmr2 = tfmrlist[ifmr]
                fmr2 = fmrlist[ifmr]

                ifnmr += 1
                tfnmr1 = tfnmr2
                fnmr1 = fnmr2
                tfnmr2 = tfnmrlist[ifnmr]
                fnmr2 = fnmrlist[ifnmr]
        #print("FMR: t1=%s,fmr=%s; t2=%s,fmr=%s" % (tfmr1,fmr1,tfmr2,fmr2))
        #print("FNMR: t1=%s,fnmr=%s; t2=%s,fnmr=%s" % (tfnmr1,fnmr1,tfnmr2,fnmr2))
        
        if fmr2 == fnmr2:
            eer = fmr2
            t = min(tfmr2,tfnmr2)
            #print("EER: %s, t: %s" %(eer,t))
        else:
            l = max(fnmr1,fmr2)
            h = min(fnmr2,fmr1)
            eer = (l+h)/2
            t1 = max(tfmr1,tfnmr1)
            t2 = min(tfmr2,tfnmr2)
            t = (t1+t2)/2
            #print("l=%s, h=%s; t1=%s, t2=%s" %(l,h,t1,t2))
            #print("EER: %s, t: %s" %(eer,t))
        return eer

Walks the two step functions (the details were not studied closely here) to find the crossing point of FNMR(t) and FMR(t), the equal error rate (EER).
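The idea can be checked with a naive threshold scan (a simplified sketch, not the code's interpolating walk): FNMR(t) is the fraction of intra-cluster scores at or below t, FMR(t) the fraction of inter-cluster scores above t, and the EER sits where the two curves meet.

    def naive_eer(inner_scores, inter_scores, steps=1000):
        # scan thresholds and return the point where FNMR and FMR are closest
        best = (1.0, 0.0)
        for k in range(steps + 1):
            t = k / steps
            fnmr = sum(s <= t for s in inner_scores) / len(inner_scores)
            fmr = sum(s > t for s in inter_scores) / len(inter_scores)
            if abs(fnmr - fmr) < abs(best[0] - best[1]):
                best = (fnmr, fmr)
        return (best[0] + best[1]) / 2

    # intra scores high, inter scores low -> small EER
    print(naive_eer([0.9, 0.95, 0.85], [0.2, 0.3, 0.1]))  # 0.0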

Computing Pm: compute_similarity_constraints

    def compute_similarity_constraints(self, inner_inter_scores):
        symbol_m = {}
        for key,values in inner_inter_scores.items():
            symbol_m[key] = 1 - self.compute_eer(values[1], values[2])
        return symbol_m

Computes pm = 1 - EER for each symbol (i.e. each cluster), e.g. symbol_m['14'] = 0.5476190476190477, symbol_m['08'] = 0.6615646258503401, and so on.

Collecting Pm: compute_constraint_message_similarity

    # compute p_m
    def compute_constraint_message_similarity(self, symbols):
        logging.debug("[+] Compute observation probabilities of message similarity")
        sn_list = [str(s.name) for s in symbols.values()] # the symbol names under the current keyword candidate

        inner_inter_scores = self.compute_inner_inter_scores(symbols)
        symbol_m = self.compute_similarity_constraints(inner_inter_scores)

        p_m = list()
        for s in sn_list: # collect pm, in symbol order, into the p_m list
            if symbol_m[s] > 0: #!= -1:
                p_m.append(symbol_m[s])
                #p_m.append(dict_scores_test[s] * dict_sn_msgnum_test[s] / msgnum_total_test)
            elif len(sn_list) == 1: # no inter scores 
                p_m.append(-2)
            else: # no inner scores
                p_m.append(-1) # TODO: may not need it

        return p_m

The final output is the list of message-similarity constraints, one per cluster. For example, when dnp3 is clustered with fid = 1 as the keyword field, the four clusters' constraints are p_m = [0.5476190476190477, 0.6615646258503401, 0.04081632653061229, 0.4707207207207207].


remote_coupling.py

This module computes the remote coupling constraint.

Imports

import copy
import logging

from netzob.Model.Vocabulary.Session import Session

Session: Netzob functionality for retrieving the true sessions embedded in the current message set; it gives the number of sessions and each session's endpoints.

RemoteCoupling class initialization

class RemoteCoupling:
    TEST_TYPE_REQUEST = 0
    TEST_TYPE_RESPONSE = 1

    def __init__(self, messages_all, symbols_request, symbols_response, direction_list):
        self.messages_all = messages_all
        self.symbols_request = symbols_request
        self.symbols_response = symbols_response
        self.direction_list = direction_list

        self.pairs_request = dict()
        self.pairs_response = dict()

**messages_all:** all messages (already aligned)

**symbols_request:** the clustering result on the request side

**symbols_response:** the clustering result on the response side

**direction_list:** the direction of each message

**pairs_request:** the request-to-response pairing scores

**pairs_response:** the response-to-request pairing scores

Computing request/response pairs with the direction list: compute_pairs_by_directionlist

    def compute_pairs_by_directionlist(self):
        logging.debug("[+] Compute request/respnse pairs info")
        # compute request/response message pairs
        symbolList_request = list(self.symbols_request.values())
        symbolList_response = list(self.symbols_response.values())
        symbolNameList_request = [str(s.name) for s in self.symbols_request.values()]
        symbolNameList_response = [str(s.name) for s in self.symbols_response.values()]

**symbolNameList_request:** the names of the request-side cluster symbols

        # generate new messages
        messages = copy.deepcopy(self.messages_all)
        sessions = Session(messages)
        # lenofSession = len(sessions.getTrueSessions())
        # print("lenth of session: {0}".format(lenofSession))

Here the Session class from the imports is used. Information can be obtained as follows.

Number of sessions:

    lenofSession = len(sessions.getTrueSessions())  # e.g. 4

Each session's endpoints:

    for trueSession in sessions.getTrueSessions():
        print(trueSession.name)

For example:

    Session: '192.168.0.197:20000' - '192.168.0.198:36639'
    Session: '192.168.0.197:20000' - '192.168.0.198:43479'
    Session: '192.168.0.197:20000' - '192.168.0.198:55748'
    Session: '192.168.0.197:20000' - '192.168.0.198:51955'

        dict_mid_sn = dict()
        for s in self.symbols_request.values():
            sn = str(s.name)
            for message in s.messages:
                dict_mid_sn[message.id] = sn
        for s in self.symbols_response.values():
            sn = str(s.name)
            for message in s.messages:
                dict_mid_sn[message.id] = sn

        for i in range(len(self.direction_list)):
            data = [dict_mid_sn[messages[i].id], self.direction_list[i]]
            messages[i].data = data

**message.id:** an opaque unique identifier generated by Netzob for each message; used here as the message's ID.

dict_mid_sn[message.id]: keyed by each message's id; the stored value is the candidate-field symbol name that the message belongs to.

**messages[i].data:** replaced by a two-item list: the symbol name of the current candidate field for messages[i], and the message's direction (request or response side).

        # count pair info
        dict_request, dict_response = dict(), dict()
        for sn in symbolNameList_request:
            dict_request[sn] = dict()
        for sn in symbolNameList_response:
            dict_response[sn] = dict()

Creates the dicts dict_request[sn] and dict_response[sn], keyed by the request-side and response-side symbol names respectively.

        # TODO: improve
        for session in sessions.getTrueSessions(): # loop within one session at a time; a session holds many messages
            messages_list = list(session.messages.values())
            messages_list = sorted(messages_list, key=lambda x:x.date)

            #Check if it is invalid (the first is request)
            '''
            if messages_list[0].data[1] != 0:
                continue
            '''
            '''
            # Check if it is invalid (only request)
            srcIP_list = [message.source for message in messages_list]
            if len(set(srcIP_list)) == 1:
                #print("This session is invalid.")
                continue
            '''

**messages_list:** the session's messages as wrapped by Netzob; they can be inspected as follows.

Printing the dates:

    for message in messages_list:
        f.write("messages_list date: {0}\n".format(message.date))

Printing the data:

    for message in messages_list:
        f.write("messages_list data: {0}\n".format(message.data))  # prints the current candidate keyword symbol name and the direction flag

            # Find the first request msg of each session
            i_first_request_msg = -1
            for i,message in enumerate(messages_list):
                if message.data[1] == 0:
                    i_first_request_msg = i # after sorting by date, just record the index of the first message whose direction flag is 0
                    break
            if i_first_request_msg == -1:
                continue
            #requestSrcIP = str(messages_list[0].source)
            preRequestS = None  
            for message in messages_list[i_first_request_msg:]:
                sn = message.data[0]
                if message.data[1] == 0:
                    preRequestS = sn
                else:
                    if sn in dict_request[preRequestS]:
                        dict_request[preRequestS][sn] += 1
                    else:
                        dict_request[preRequestS][sn] = 1
                    if preRequestS in dict_response[sn]:
                        dict_response[sn][preRequestS] += 1
                    else:
                        dict_response[sn][preRequestS] = 1

Counts the pairing frequencies between the two sides. Walking each session in time order, preRequestS tracks the most recent request-side cluster; every response message with symbol sn then increments dict_request[preRequestS][sn] (how often request cluster preRequestS is answered by response cluster sn) and, symmetrically, dict_response[sn][preRequestS]. The data structure is a nested dict {request symbol: {response symbol: count}} plus its mirror.
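A toy illustration of the counting (hypothetical clusters 'a', 'b' on the request side and 'x' on the response side):

    # messages as (symbol_name, direction) in time order, one session
    session_msgs = [('a', 0), ('x', 1), ('a', 0), ('x', 1), ('b', 0), ('x', 1)]

    dict_request = {'a': {}, 'b': {}}
    dict_response = {'x': {}}
    preRequestS = None
    for sn, direction in session_msgs:
        if direction == 0:
            preRequestS = sn
        else:
            dict_request[preRequestS][sn] = dict_request[preRequestS].get(sn, 0) + 1
            dict_response[sn][preRequestS] = dict_response[sn].get(preRequestS, 0) + 1

    assert dict_request == {'a': {'x': 2}, 'b': {'x': 1}}
    assert dict_response == {'x': {'a': 2, 'b': 1}}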

        # compute pairs constraints results
        # method 1: use the lenth
        # method 2: use the proportion of the larger one
        for s in symbolNameList_request:
            # method 1
            # self.pairs_request[s] = 1 / len(dict_request[s])
            # method 2
            list_msgcount= sorted(dict_request[s].items(),key=lambda x:x[1],reverse=True)
            # print(list_msgcount)
            count_total = 0
            for item in list_msgcount:
                count_total += item[1]
            if len(list_msgcount) > 0:
                self.pairs_request[s] = list_msgcount[0][1] / count_total
            else:
                self.pairs_request[s] = 0
        for s in symbolNameList_response:
            # method 1
            #self.pairs_response[s] = 1 / len(dict_response[s])
            # method 2
            list_msgcount= sorted(dict_response[s].items(),key=lambda x:x[1],reverse=True)
            count_total = 0
            for item in list_msgcount:
                count_total += item[1]
            if len(list_msgcount) > 0:
                self.pairs_response[s] = list_msgcount[0][1] / count_total
            else:
                self.pairs_response[s] = 0

Computes the remote coupling constraint: for each cluster on the request and response sides, Pr is the proportion taken by its most frequent counterpart cluster (method 2 in the comments); the values are recorded in self.pairs_request and self.pairs_response respectively.

Example output:

    self.pairs_request: dict_items([('14', 0.5), ('08', 0.5), ('11', 1.0), ('0b', 1.0)])
    self.pairs_response: dict_items([('0a', 0.9024390243902439), ('ff', 0.5), ('1f', 0.5)])

Collecting Pr: compute_constraint_remote_coupling

    def compute_constraint_remote_coupling(self, direction):
        test_type = "request" if direction == RemoteCoupling.TEST_TYPE_REQUEST else "response"
        logging.debug("[+] Compute observation probabilities of remote coupling: {}".format(test_type))
        
        symbols = self.symbols_request if direction == RemoteCoupling.TEST_TYPE_REQUEST else self.symbols_response
        pairs = self.pairs_request if direction == RemoteCoupling.TEST_TYPE_REQUEST else self.pairs_response

        sn_list = [str(s.name) for s in symbols.values()]
        p_r = list()
        for s in sn_list:
            if pairs[s] > 0:
                p_r.append(pairs[s])
            else:
                p_r.append(-1)

        return p_r

Splits the computed p_r by direction (request side vs response side); it is called from constraint.py. For example:

    p_r_request: [0.5, 0.5, 1.0, 1.0]
    p_r_response: [0.9024390243902439, 0.5, 0.5]


constraint.py

This part computes the four constraint probabilities of the keyword candidate fields, which feed the probabilistic inference.

Imports

import os
import logging
import copy
import collections
import gc

from netzob.Model.Vocabulary.Symbol import Symbol
from netzob.Model.Vocabulary.Field import Field
from netzob.Model.Vocabulary.Types.Raw import Raw 
#from netzob.Import.PCAPImporter.all import *
#from netzob.Model.Vocabulary.Session import Session

from processing import Processing
from alignment import Alignment
from constraint.message_similarity import MessageSimilarity
from constraint.remote_coupling import RemoteCoupling

collections: a Python standard-library module implementing specialized container datatypes as alternatives to the built-in dict, list, set, and tuple.

Reference: collections — Container datatypes — Python 3.11.2 documentation

**Field** and **Raw** are the same Netzob classes already introduced in the netplier.py section above; see there for the usage examples.

Symbol: a Netzob class; usage below.

A symbol is a common abstraction over a set of messages; we can create one based on two raw messages.

Example 1:

Suppose there are two messages m1 = "hello world" and m2 = "hello earth".

Then let fields = [Field("hello ", name="f0"), Field(["world", "earth"], name="f1")];

symbol = Symbol(fields, messages=[m1, m2])

In fields there is a fixed value 'hello ' named 'f0' and an alternative value ["world", "earth"] named 'f1'; the symbol's grammar captures exactly these two messages. The symbol renders as:

    f0       | f1
    -------- | -------
    'hello ' | 'world'
    'hello ' | 'earth'

Example 2:

    s = Symbol([Field("hello ", name="f0"), Field(ASCII(nbChars=(0, 10)), name="f1")])
    s.messages.append(RawMessage("hello toto"))

s renders as:

    f0       | f1
    -------- | -------
    'hello ' | 'toto'

Constraint class initialization

class Constraint:
    TEST_TYPE_REQUEST = 0
    TEST_TYPE_RESPONSE = 1
    #FILENAME_P_REQUEST = "prob_request.txt"
    #FILENAME_P_RESPONSE = "prob_response.txt"

    def __init__(self, messages, direction_list, fields, fid_list, output_dir='tmp/'):
        self.messages = messages
        self.direction_list = direction_list
        self.fields = fields
        self.fid_list = fid_list
        self.output_dir = output_dir

Computing the observation (prior) probabilities: compute_observation_probabilities

    def compute_observation_probabilities(self): # compute the observation (prior) probabilities
        print("[++++++++] Compute probabilities of observation constraints")
        messages_aligned = Alignment.get_messages_aligned(self.messages, os.path.join(self.output_dir, Alignment.FILENAME_OUTPUT_ONELINE)) # load the aligned per-line data into messages_aligned
        messages_request, messages_response = Processing.divide_msgs_by_directionlist(self.messages, self.direction_list)
        messages_request_aligned, messages_response_aligned = Processing.divide_msgs_by_directionlist(messages_aligned, self.direction_list) # split the aligned data into request and response sides

        fid_list_request = self.filter_fields(self.fields, self.fid_list, messages_request_aligned) # candidate keyword field indices, request side
        fid_list_response = self.filter_fields(self.fields, self.fid_list, messages_response_aligned) # candidate keyword field indices, response side
        logging.debug("request candidate fid: {}\nresponse candidate fid: {}".format(fid_list_request, fid_list_response))

os.path.join() joins path components.

        # compute matrix of similarity scores
        constraint_m_request, constraint_m_response = MessageSimilarity(messages = messages_request_aligned), MessageSimilarity(messages = messages_response_aligned)
        constraint_m_request.compute_similarity_matrix()
        constraint_m_response.compute_similarity_matrix()

Builds the similarity score matrices of the request and response sides; constraint_m_request.similarity_matrix can be printed for inspection.

        # the observation prob of each cluster: {fid: the list of observation probabilities ([pm,ps,pd,pv])} 
        cluster_p_request, cluster_p_response = dict(), dict() 
        # the size of each cluster
        cluster_size_request, cluster_size_response = dict(), dict()
        # the observation prob of each cluster pair: {fid-fid: [,]}
        pairs_p_request, pairs_p_response = dict(), dict()
        pairs_size_request, pairs_size_response = dict(), dict()

cluster_p_request and cluster_p_response hold, for the request and response sides, the lists of the four observation probabilities [pm, ps, pd, pv].

cluster_size_request and cluster_size_response hold the size of each cluster, per side.

The last two dicts hold the observation probabilities of cluster pairs.

        for fid_request in fid_list_request:
            logging.info("[++++] Test Request Field {0}-*".format(fid_request))

            # merge other fields
            fields_merged_request = self.merge_nontest_fields(self.fields, fid_request)
            fid_merged_request = 0 if fid_request == 0 else 1

            # generate clusters
            symbols_request_aligned = self.cluster_by_field(fields_merged_request, messages_request_aligned, fid_merged_request)
            # change symbol names
            symbols_request_aligned = self.change_symbol_name(symbols_request_aligned)

The steps here are, in order: merge all fields before the current one into a single field, generate the clusters, and rename each cluster's symbol; see the detailed explanation of each function below.

            # compute prob of m,s,d,v
            cluster_p_request[fid_request] = list()
            cluster_p_request[fid_request].append(constraint_m_request.compute_constraint_message_similarity(symbols_request_aligned)) # message similarity constraint per cluster; see compute_constraint_message_similarity for a concrete example
            cluster_p_request[fid_request].append(self.compute_constraint_structure(symbols_request_aligned))
            cluster_p_request[fid_request].append(self.compute_constraint_dimension(symbols_request_aligned))
            cluster_p_request[fid_request].append(self.compute_constraint_value(symbols_request_aligned))
            cluster_size_request[fid_request] = [len(s.messages) for s in symbols_request_aligned.values()]
            # cluster_size_request holds the number of messages in each cluster, e.g. cluster_size_request[1] == [8, 4, 4, 37], where 1 is the candidate keyword field index
            for fid_response in fid_list_response:
                #if fid_request != fid_response:
                #    continue
                logging.debug("[++] Test Response Field {0}-{1}".format(fid_request, fid_response))

                # merge other fields
                fields_merged_response = self.merge_nontest_fields(self.fields, fid_response)
                fid_merged_response = 0 if fid_response == 0 else 1

                # generate clusters
                symbols_response_aligned = self.cluster_by_field(fields_merged_response, messages_response_aligned, fid_merged_response)
                # change symbol names
                symbols_response_aligned = self.change_symbol_name(symbols_response_aligned)

                # compute prob of m,s,d,v
                if fid_response not in cluster_p_response:
                    cluster_p_response[fid_response] = list()
                    cluster_p_response[fid_response].append(constraint_m_response.compute_constraint_message_similarity(symbols_response_aligned))
                    cluster_p_response[fid_response].append(self.compute_constraint_structure(symbols_response_aligned))
                    cluster_p_response[fid_response].append(self.compute_constraint_dimension(symbols_response_aligned))
                    cluster_p_response[fid_response].append(self.compute_constraint_value(symbols_response_aligned))
                    cluster_size_response[fid_response] = [len(s.messages) for s in symbols_response_aligned.values()]

This part mirrors the request side; the logic is identical. The only difference is that fid_list_response is [1, 2, 3, 5, 7, 8, 10, 11, 12, 13], ten entries, without field 9.

                # print msg numbers of each cluster
                logging.debug("Number of request symbols: {0}".format(len(symbols_request_aligned.values()))) # number of clusters on the request side
                for s in symbols_request_aligned.values():
                    logging.debug("  Symbol {0} msgs numbers: {1}".format(str(s.name), len(s.messages))) # number of messages in each request-side cluster
                logging.debug("Number of response symbols: {0}".format(len(symbols_response_aligned.values()))) # number of clusters on the response side
                for s in symbols_response_aligned.values():
                    logging.debug("  Symbol {0} msgs numbers: {1}".format(str(s.name), len(s.messages))) # number of messages in each response-side cluster
                # compute remote coupling probabilities
                rc = RemoteCoupling(messages_all=messages_aligned, symbols_request=symbols_request_aligned, symbols_response=symbols_response_aligned, direction_list=self.direction_list)
                rc.compute_pairs_by_directionlist()
                fid_pair = "{}-{}".format(fid_request, fid_response)
                p_r_request = rc.compute_constraint_remote_coupling(RemoteCoupling.TEST_TYPE_REQUEST)
                p_r_response = rc.compute_constraint_remote_coupling(RemoteCoupling.TEST_TYPE_RESPONSE)

The resulting p_r_request and p_r_response are the remote coupling constraints for the request side and the response side respectively, stored as lists of per-cluster values.

                logging.debug("[+] Observation Prob Results for pairs {}".format(fid_pair))
                p_m, p_s, p_d, p_v = cluster_p_request[fid_request][0], cluster_p_request[fid_request][1], cluster_p_request[fid_request][2], cluster_p_request[fid_request][3] # the previously computed message similarity, structure coherence, dimension, and value constraints
                logging.debug("Request:\nPm: {0}\nPr: {1}\nPs: {2}\nPd: {3}\nPv: {4}".format(p_m, p_r_request, p_s, p_d, p_v))
                pairs_p_request[fid_pair] = [p_m, p_r_request, p_s, p_d, p_v]
                pairs_size_request[fid_pair] = cluster_size_request[fid_request]

**pairs_p_request[fid_pair]:** for example, when fid_pair is 1-1, pairs_p_request holds the request side's message similarity, remote coupling, structure coherence, dimension, and value constraints.

pairs_size_request[fid_pair]: for example, when fid_pair is 1-1, pairs_size_request is the number of messages in each request-side cluster formed with fid_request as the candidate field.

                p_m, p_s, p_d, p_v = cluster_p_response[fid_response][0], cluster_p_response[fid_response][1], cluster_p_response[fid_response][2], cluster_p_response[fid_response][3]
                logging.debug("Response:\nPm: {0}\nPr: {1}\nPs: {2}\nPd: {3}\nPv: {4}".format(p_m, p_r_response, p_s, p_d, p_v))
                pairs_p_response[fid_pair] = [p_m, p_r_response, p_s, p_d, p_v]
                pairs_size_response[fid_pair] = cluster_size_response[fid_response]

**pairs_p_response[fid_pair]:** for example, when fid_pair is 1-1, pairs_p_response holds the response side's message similarity, remote coupling, structure coherence, dimension, and value constraints.

pairs_size_response[fid_pair]: for example, when fid_pair is 1-1, pairs_size_response is the number of messages in each response-side cluster formed with fid_response as the candidate field.

                del rc
                del symbols_response_aligned #symbols
                del fields_merged_response
                gc.collect()
            del symbols_request_aligned
            del fields_merged_request
            gc.collect()
            # delete unused variables to free memory
        pairs_p = [pairs_p_request, pairs_p_response]
        pairs_size = [pairs_size_request, pairs_size_response]

        return pairs_p, pairs_size

pairs_p: holds the constraint values for the request and response sides.

**pairs_size:** holds the per-cluster message counts for the request and response sides under each candidate keyword field.

Saving the observation probability results: save_observation_probabilities

    def save_observation_probabilities(self, pairs_p, pairs_size, direction):
        filename = "prob_request.txt" if direction == Constraint.TEST_TYPE_REQUEST else "prob_response.txt"
        filepath = os.path.join(self.output_dir, filename)
        
        fid_pair_list = sorted(pairs_p.keys(), key= lambda x: (int(x.split('-')[direction]), int(x.split('-')[1 - direction])))
        # Write into files
        with open(filepath, 'w') as fout:
            for fid_pair in fid_pair_list:
                fout.write("{} ".format(fid_pair))

                # write Pm/r/s/d/v
                for p_list in pairs_p[fid_pair]:
                    for p in p_list[:-1]:
                        fout.write("{},".format(p))
                    fout.write("{} ".format(p_list[-1]))

                for n in pairs_size[fid_pair][:-1]:
                    fout.write("{},".format(n))
                fout.write("{} ".format(pairs_size[fid_pair][-1]))
                fout.write("\n")

prob_request.txt: using the first line of this file as an example, here is what each piece of data means.

Example: 1-2 0.5476190476190477,0.6615646258503401,0.04081632653061229,0.4707207207207207 1.0,1.0,1.0,1.0 1.0,1.0,1.0,1.0 1.0 1 8,4,4,37

1-2: candidate field 1 is taken as the keyword field on the request side, and candidate field 2 as the keyword field on the response side.

0.5476190476190477,0.6615646258503401,0.04081632653061229,0.4707207207207207: with field 1 as the request-side keyword field, four clusters are produced; these are the message similarity constraints of the four clusters.

1.0,1.0,1.0,1.0: with field 1 on the request side and field 2 on the response side, the remote coupling constraints of the four request-side clusters are 1.0, 1.0, 1.0, 1.0.

1.0,1.0,1.0,1.0: the structure coherence constraints of the four request-side clusters are 1.0, 1.0, 1.0, 1.0.

1.0: the proportion of clusters containing more than two messages is 1.0.

1: set to 1 if splitting on request-side field 1 yields more than one cluster.

8,4,4,37: the numbers of messages in the four clusters are 8, 4, 4, 37.
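To make the layout concrete, here is a hypothetical helper (not part of NetPlier) that parses one such line back into its components:

    def parse_prob_line(line):
        parts = line.split()
        fid_pair = parts[0]                            # e.g. "1-2"
        p_m = [float(x) for x in parts[1].split(',')]  # message similarity per cluster
        p_r = [float(x) for x in parts[2].split(',')]  # remote coupling per cluster
        p_s = [float(x) for x in parts[3].split(',')]  # structure coherence per cluster
        p_d = [float(x) for x in parts[4].split(',')]  # dimension constraint
        p_v = [float(x) for x in parts[5].split(',')]  # value constraint
        sizes = [int(x) for x in parts[6].split(',')]  # cluster sizes
        return fid_pair, [p_m, p_r, p_s, p_d, p_v], sizes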

Merging the other fields not tested in this run: merge_nontest_fields

# merge other fields that are not tested in this run
    def merge_nontest_fields(self, fields_origin, fid):
        logging.debug("[+] Merge Fields")
        fields_merged = list()
        fields = copy.deepcopy(fields_origin)
        fsize_total = 0
        for i in range(len(fields)):
            typename = fields[i].domain.dataType.typeName
            if typename != "Raw":
                logging.error("Field type is not Raw")
            typesize = fields[i].domain.dataType.size[1]
            # print(i, fid, typesize, type(typesize))

            if i == fid:
                if fsize_total > 0: # merge all fields before the current one into a single Raw field
                    field = Field(Raw(nbBytes=fsize_total//8))
                    fields_merged.append(field)
                    fsize_total = 0
                if fid != (len(fields) - 1):    # record the size of the current field
                    field = Field(Raw(nbBytes=typesize//8))
                    fields_merged.append(field)
                    field = Field()
                    fields_merged.append(field)
                else:
                    if type(typesize).__name__ == 'NoneType':
                        field = Field()
                    else:
                        field = Field(Raw(nbBytes=typesize//8))
                    fields_merged.append(field)
                break
            else:
                fsize_total += typesize   # accumulate the total length of all fields before fid

        return fields_merged

For example, when fid == 5, the returned fields_merged list has three entries: a field merging the type and size of everything before the current field, the current field itself, and a zero-size open-ended field of the same type. This prepares the ir value needed later by cluster_by_field, i.e.:

merge_nontest_field[0]: typeName Raw, size[0] 96, size[1] 96
merge_nontest_field[1]: typeName Raw, size[0] 16, size[1] 16
merge_nontest_field[2]: typeName Raw, size[0] 0, size[1] None
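A toy recomputation of that example (the field sizes here are assumptions chosen for illustration, in bits):

    # assumed max field sizes in bits for fields 0..6 (hypothetical values)
    sizes = [16, 16, 32, 16, 16, 16, 8]
    fid = 5
    merged_prefix_bits = sum(sizes[:fid])    # 96 -> Field(Raw(nbBytes=96//8))
    current_bits = sizes[fid]                # 16 -> Field(Raw(nbBytes=16//8))
    print(merged_prefix_bits, current_bits)  # the third entry is an open-ended Field()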

Checking message lengths: has_short_msg

    def has_short_msg(self, messages, length):
        for message in messages:
            if len(message.data) <= length:
                return True
        return False

Returns True if any message's length is no greater than the given length.

Filtering out impossible fields: filter_fields

    def filter_fields(self, fields, fid_list, messages):
        logging.debug("[++++] Filter Fields")
        fid_list_new = list()
        for fid in fid_list:
            logging.debug("\n[+] Test Field_{0}".format(fid))

            il, ir = 0, 0
            for i in range(fid):
                il += fields[i].domain.dataType.size[1] // 8
            ir = il + (fields[fid].domain.dataType.size[1] // 8)

fields holds the allowed length range of each field from the msa_fields_info file, in bytes, i.e. Field(Raw(nbBytes=(typeinfo[1]//8, typeinfo[2]//8))).

fid_list holds the indices of the fields marked 'D' (dynamic value, fixed length) in msa_fields_info.

fields[i].domain.dataType.size[1] holds the maximum allowed field length from msa_fields_info.

The fields and fid_list variables are built in the generate_fields_by_fieldsinfo function of netplier.py.

il and ir are used below to locate the field within each message.

Question: generate_fields_by_fieldsinfo already divided the bit counts by 8 to get bytes, so why divide by 8 again here? (Most likely because Netzob stores a Raw field's size internally in bits even when the field is created with nbBytes, so size[1] is a bit count that must be converted back to bytes.)

            # -1: the test field is too long
            if (fields[fid].domain.dataType.size[1] // 8) > 10:
                logging.debug("The tested field is too long.")
                continue

When a dynamic fixed-length field is longer than 10 bytes, it is discarded.

            # -2: message is too short to have field[fid]
            if self.has_short_msg(messages, ir):
                ##Check if the symbol_ntest side has field_merged.
                ##If the fields is empty, there will be InvalidParsingPathException error in computeFGP-clusterByKeyField
                logging.debug("Some messages doesn't have this field.")
                continue

A field is also discarded when some message is too short to contain it.

            #-3: too many symbols (>60%)
            # TODO
            f_values = [message.data[il:ir] for message in messages]
            percentage = len(messages) / len(set(f_values))
            if percentage < 1.5 or len(set(f_values)) > 50: # TODO: save time, but may cause error in small data set (modbus_100)
                logging.debug("There are too many symbols")
                continue

            fid_list_new.append(fid)

        #print(len(fid_list_new), fid_list_new)
        return fid_list_new

Fields that would produce too many symbols are discarded.
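For example, with 53 messages whose field values take 4 distinct values, percentage = 53 / 4 = 13.25 >= 1.5, so the field is kept; a field where nearly every message carries a distinct value gives a percentage close to 1 (or more than 50 distinct values) and is dropped.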

Computing the structure coherence constraint p_s: compute_constraint_structure

    def compute_constraint_structure(self, symbols):
        logging.debug("[+] Compute observation probabilities of structure coherence")
        sn_list = [str(s.name) for s in symbols.values()]

        # if there is only one msg, then it is always 1.0
        dict_result = dict() 
        for s in symbols.values():
            # compute the num of gaps shared by all msgs
            num_gap_extra = 0
            for i in range(len(s.messages[0].data)):
                valuelist = [message.data[i] for message in s.messages]
                if len(set(valuelist)) == 1 and valuelist[0] == '-':
                    num_gap_extra += 1
            #print("Num Extra Gaps: {}".format(num_gap_extra))
            
            # compute ave num of gaps
            num_gap = 0 
            for message in s.messages:
                num_gap += (message.data.count("-") - num_gap_extra)

            num_gap_ave = num_gap / len(s.messages)
            percentage_gap = num_gap_ave / (len(s.messages[0].data) - num_gap_extra)
            dict_result[s.name] = [1 - percentage_gap, num_gap_ave]

        p_s = list()
        for s in sn_list:
            p_s.append(dict_result[s][0])

        return p_s

I suspect this computation is problematic; it does not match the paper's description. The output p_s is the structure coherence constraint of each cluster, e.g. p_s: [1.0, 1.0, 1.0, 1.0].
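A worked example of the code above (made-up aligned data): suppose a cluster has two aligned messages 'ab-d' and 'ab--'. Position 2 is '-' in every message, so num_gap_extra = 1. The remaining per-message gaps are 1-1 = 0 and 2-1 = 1, so num_gap_ave = 0.5, percentage_gap = 0.5 / (4-1) ≈ 0.167, and the cluster's p_s entry is about 0.833.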

Computing the dimension constraint p_d: compute_constraint_dimension

    # compute p_d
    def compute_constraint_dimension(self, symbols):
        logging.debug("[+] Compute observation probabilities of dimension")
        num_smallsymbols = 0
        for s in symbols.values():
            if len(s.messages) <= 2:
                num_smallsymbols += 1

        p = 1 - num_smallsymbols / len(symbols.values())
        p_d = [p]

        return p_d

Up to this point the implementation does not go as far as the paper does; it merely computes the proportion of clusters containing more than two messages.
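For example, with cluster sizes [8, 4, 4, 37] no cluster has two or fewer messages, so p_d = [1.0]; with sizes [8, 2, 4] one of the three clusters is small, so p_d = [1 - 1/3] ≈ [0.667].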

Computing p_v: compute_constraint_value

    def compute_constraint_value(self, symbols):
        # TODO: may not need it
        if len(symbols.values()) == 1:
            p = -1
        else:
            p = 1
        p_v= [p]

        return p_v

This flags the case where splitting on the keyword yields only a single cluster. It generally does not occur, because static fields are excluded from the candidate keyword fields.

Clustering by the current field: cluster_by_field

    def cluster_by_field(self, fields, messages, fid_merged):
        logging.debug("[+] Generate Clusters")
        if fid_merged == 0:
            il = 0
            ir = fields[0].domain.dataType.size[1] // 8
        elif fid_merged == 1:
            il = fields[0].domain.dataType.size[1] // 8
            ir = il + (fields[1].domain.dataType.size[1] // 8)
        else:
            logging.error("Error: fid_merged should be 0 or 1")
        # il is the start offset of the current field, ir its end offset

        # f_values stores, for each message on the request/response side, the bytes at the current field's position
        f_values = [message.data[il:ir] for message in messages]

        # assign each message index to the dict entry of its candidate keyword field value
        dict_fv_i = dict()
        for i,fv in enumerate(f_values):
            if fv not in dict_fv_i:
                dict_fv_i[fv] = list()
            dict_fv_i[fv].append(i)

        # create one Symbol per key of dict_fv_i
        symbols = collections.OrderedDict()
        for fv in dict_fv_i:
            s = Symbol(name=fv, messages=[messages[i] for i in dict_fv_i[fv]])
            symbols[fv] = s

        return symbols

About dict_fv_i:

Suppose '14' is one of the candidate keyword field values, and the first five messages are:

message[0].data[4:6]: 14
message[1].data[4:6]: 08
message[2].data[4:6]: 11
message[3].data[4:6]: 14
message[4].data[4:6]: 0b

Then dict_fv_i['14'] = [0, 3, ...] and dict_fv_i['08'] = [1, ...].

About symbols:

Continuing the example, symbols['14'] = Symbol('14', [messages[0], messages[3], ...]), where each messages[i] is a concrete message; the request side has 8 messages with value '14' (abbreviated here for length), and these concrete messages are used later.
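A self-contained toy run of the same grouping idea (with made-up message data, using il=4 and ir=6 as in the example above):

    data = ["....14..", "....08..", "....11..", "....14..", "....0b.."]
    dict_fv_i = {}
    for i, msg in enumerate(data):
        fv = msg[4:6]
        dict_fv_i.setdefault(fv, []).append(i)
    print(dict_fv_i)  # {'14': [0, 3], '08': [1], '11': [2], '0b': [4]}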

Normalizing symbol names: change_symbol_name

    def change_symbol_name(self, symbols):
        logging.debug("[+] Change symbol names")
        for keyFieldName, symbol in symbols.items():
            if type(keyFieldName).__name__ == "bytes":
                keyFieldName = binascii.unhexlify(keyFieldName)
                keyFieldName = keyFieldName.hex()
                symbol.name = str(keyFieldName) # convert byte-form keys in symbols to strings
            else:
                symbol.name = keyFieldName
            if len(symbol.name) > 40:
                md5 = hashlib.md5()
                md5.update(symbol.name.encode('utf-8')) # replace keys longer than 40 characters with their MD5 hash
                symbol.name = str(md5.hexdigest())
        return symbols

keyFieldName is a dict key of symbols, e.g. '14'; symbol is the data associated with that key.


probabilistic_inference.py

This part infers the keyword field.

Module imports

from sklearn import preprocessing
import numpy as np
import copy
import logging

from factor_graph import MyFactorGraph

preprocessing: used for data preprocessing.

**np:** used for multi-dimensional array and matrix operations.

ProbabilisticInference class initialization

class ProbabilisticInference:
    P_K2M, P_M2K = 0.8, 0.6 #0.8, 0.6
    P_K2R, P_R2K = 0.9, 0.6
    P_K2S, P_S2K = 0.9, 0.6 #0.9, 0.7
    P_K2D, P_D2K = 0.9, 0.6
    P_K2V, P_V2K = 0.9, 0.6

    BONUS_VALUE_X2K = 0.2

    def __init__(self, pairs_p, pairs_size):
        self.pairs_p = pairs_p # observation prob
        self.pairs_size = pairs_size

P_K2M and the like are the implication probabilities; the paper describes their values as follows: "In NETPLIER, probabilities p→ are set to be 0.8 for message similarity constraints and 0.9 for the others. Probabilities p← lies in [0.6, 0.8] depending on cluster sizes."

**pairs_p:** the dict of observation probabilities, keyed by '{}-{}' (the field index on one side paired with the field index on the other); each value contains p_m, p_r, p_s, p_d, p_v.

pairs_size: keyed the same way; each value is the number of messages in each cluster produced by splitting on the current field.

Starting inference: execute

# inference
    def execute(self, fid_list = None):
        print("[++++++++] Infer the keyword")

        # update fid_list if it is specified
        if fid_list == None:
            fid_list = list(self.pairs_p.keys()) 
        else:
            fid_list = [fid for fid in fid_list if fid in self.pairs_p]
        logging.debug("fid_list: {}".format(fid_list)) #debug

This determines fid_list; for dnp3 it is ['1-1', '2-2', '3-3', '5-5', '7-7', '8-8', '10-10', '11-11', '12-12', '13-13'].

        # compute implication probabilities
        self.p_implication = dict()
        for fid in self.pairs_p.keys():
            self.p_implication[fid] = self.compute_p_implication(self.pairs_p[fid], self.pairs_size[fid]) # compute the implication probabilities
            #self.p_implication[fid] = self.compute_p_implication_weighted(self.pairs_p[fid], self.pairs_size[fid])

This computes the implication probability list for every request-response pair '{}-{}'.

        p_observation = copy.deepcopy(self.pairs_p)
        #self.print_p_lists(fid_list, p_observation)

        # normalize observation prob
        p_observation = self.normalize_p_observation(p_observation) # normalize the values greater than 0; the result is still a dict
        #self.print_p_lists(fid_list, p_observation)
        # adjust observation/implication probabilities by cluster size
        logging.debug('[++++] Add bonus by size')
        # test_id: 0: m, 1: r, 2: s, 3: d, 4: v
        for fid in p_observation.keys():
            for test_id in [0, 1, 2]:
                p_observation[fid][test_id] = self.add_bonus_value(p_observation[fid][test_id], self.pairs_size[fid], 0.2)
                #self.p_implication[fid][1][test_id] = self.add_bonus_value(self.p_implication[fid][1][test_id], self.pairs_size[fid], ProbabilisticInference.BONUS_VALUE_X2K)
        # self.print_p_lists(fid_list, p_observation)

At this point every per-cluster observation probability greater than 0 in p_observation has had a bonus value added.

        # deal with p < 0
        p_observation = self.update_invalid_p(p_observation) # everything above handled p > 0; this handles p < 0
        # self.print_p_lists(fid_list, p_observation, self.p_implication)
        # factor graph
        fg_result = dict()
        for fid in fid_list:
            pk_list = list() 
            fg = MyFactorGraph(p_observation=p_observation, p_implication=self.p_implication)
            # can test different constraints together
            # test type (m/r/s/d/v): 0: k2x & x2k, 1: k2x, 2: x2k, -1: not test
            pk_list.append(fg.compute_pk([0,0,0,0,0], fid)) # kv: mrsdv, vk: mrsdv; compute the posterior for each fid
                        ## Weighted Ave
            '''
            p_list_q_weighted, p_list_s_weighted, p_list_g_weighted = p_lists_dict_weighted[fid]
            p_lists_weighted = [p_list_q_weighted, p_list_s_weighted, p_list_d, p_list_v, p_list_g_weighted]
            pk_list.append(factorgraph.compute_pk([0,0,0,0,0], p_lists_weighted, p_values_const_dict[fid]))
            '''
            fg_result[fid] = pk_list

fg_result stores, for each '{}-{}', the probability of it being the true keyword field.

        logging.debug("\n[++++] Final Result")
        pk_list_size = len(list(fg_result.values())[0]) # num of different test
        for i in range(pk_list_size):
            result = dict()
            for fid in fg_result:
                result[fid] = fg_result[fid][i]
            logging.debug(sorted(result.items(), key=lambda x:x[1], reverse=True)) # log the results sorted by probability

        return self.get_fid_inferred(fg_result) # return the inferred field id(s)

Adding the bonus value: add_bonus_value

    def add_bonus_value(self, p_list, size_list, bonus_value):
        size_sum = sum(size_list) # total number of messages across all clusters
        #p_list = [p + bonus_value * (s / size_sum) for p,s in list(zip(p_list, size_list))]
        result = list()
        for p,s in list(zip(p_list, size_list)): # p is one cluster's observation probability (a float), s is that cluster's message count
            if p > 0:
                result.append(p + bonus_value * (s / size_sum))
            else:
                result.append(p)

        return result

**list(zip(p_list, size_list)):** pairs up p_list and size_list; if p_list = [1,1,1,1] and size_list = [2,2,2,2], then list(zip(p_list, size_list)) == [(1,2),(1,2),(1,2),(1,2)].

Computing the implication probabilities: compute_p_implication

    def compute_p_implication(self, p_lists, size_list):
        p_m, p_r, p_s, p_d, p_v = p_lists

        p_ktom = [ProbabilisticInference.P_K2M] * len(p_m)
        p_ktor = [ProbabilisticInference.P_K2R] * len(p_r)
        p_ktos = [ProbabilisticInference.P_K2S] * len(p_s)
        p_ktod = [ProbabilisticInference.P_K2D] * len(p_d)
        p_ktov = [ProbabilisticInference.P_K2V] * len(p_v)


        p_mtok = [ProbabilisticInference.P_M2K] * len(p_m)
        p_rtok = [ProbabilisticInference.P_R2K] * len(p_r)
        p_stok = [ProbabilisticInference.P_S2K] * len(p_s)
        p_dtok = [ProbabilisticInference.P_D2K] * len(p_d)
        p_vtok = [ProbabilisticInference.P_V2K] * len(p_v)

        p_ktox = [p_ktom, p_ktor, p_ktos, p_ktod, p_ktov]
        p_xtok = [p_mtok, p_rtok, p_stok, p_dtok, p_vtok]
        p_implication = [p_ktox, p_xtok]

        return p_implication

This is just list repetition: taking p_ktom as an example, if ProbabilisticInference.P_K2M = 0.8 and len(p_m) = 4, then p_ktom = [0.8, 0.8, 0.8, 0.8].

Normalization and standardization: normalize_p_observation

    def normalize_p_observation(self, p_observation):
        logging.debug("\n[++++] Normalize P_lists")

        observation_id = [0, 1, 2, 3] # 0: m, 1: r, 2: s, 3: d, 4: v
        for test_id in observation_id:
            p_list_total = list()
            for fid in p_observation:
                p_list_total += p_observation[fid][test_id]

p_list_total: a flat list collecting one kind of observation probability across all request-response pairs '{}-{}'.

            # remove -1
            p_list_total = [p for p in p_list_total if p >= 0]
            if len(p_list_total) == 0:
                continue

Only the values >= 0 are kept. For the message similarity constraint this may not filter anything, but it does filter out the marker for keyword fields that produce only a single cluster. A note to flag here: could this be where the final clustering result goes wrong?

            # TODO: compute the balance value automatically
            # TODO: compute the boundary value automatically
            if test_id in [0]: # ms
                #if len(p_list_total) > 1:
                #    p_list_total = self.standardize(p_list_total)
                #p_list_total = self.normalize_max_min(p_list_total)

                p_list_total_min = np.min(p_list_total) # minimum of p_list_total
                p_list_total_max = np.max(p_list_total) # maximum of p_list_total
                if p_list_total_min != p_list_total_max:
                    p_list_total = self.normalize_range(p_list_total, p_list_total_min, p_list_total_max, 0.2, 0.80) #[0.1, 0.95] # if min != max, rescale the data into [0.2, 0.80]
                else:
                    p_balance = MyFactorGraph.compute_fg_threshold(ProbabilisticInference.P_K2M, ProbabilisticInference.P_M2K)
                    p_list_total = [p_balance for p in p_list_total]
                    #p_list_total = [0.5 for p in p_list_total]
            elif test_id in [1]: # rc
                p_list_total_min = np.min(p_list_total)
                p_list_total_max = np.max(p_list_total)
                if p_list_total_min != p_list_total_max:
                    #p_list_total = self.normalize_range(p_list_total, p_list_total_min, p_list_total_max, 0.2, 0.8)
                    p_list_total = self.normalize_range(p_list_total, 0, 1, 0.2, 0.8) # if min != max, rescale from [0, 1] into [0.2, 0.8]
                else:
                    p_balance = MyFactorGraph.compute_fg_threshold(ProbabilisticInference.P_K2R, ProbabilisticInference.P_R2K)
                    p_list_total = [p_balance for p in p_list_total]
                    #p_list_total = [0.5 for p in p_list_total]
            elif test_id in [2]: # structure
                #if len(p_list_total) > 1:
                #    p_list_total = self.standardize(p_list_total)

                p_list_total_min = np.min(p_list_total)
                p_list_total_max = np.max(p_list_total)
                if p_list_total_min != p_list_total_max:
                    #p_list_total = self.normalize_range(p_list_total, p_list_total_min, p_list_total_max, 0.2, 0.8)
                    p_list_total = self.normalize_range(p_list_total, 0, 1, 0.2, 0.8) # if min != max, rescale from [0, 1] into [0.2, 0.8]
                else:
                    p_balance = MyFactorGraph.compute_fg_threshold(ProbabilisticInference.P_K2S, ProbabilisticInference.P_S2K)
                    p_list_total = [p_balance for p in p_list_total]
                    #p_list_total = [0.5 for p in p_list_total]
            elif test_id in [3]: # d
                p_list_total_min = np.min(p_list_total)
                p_list_total_max = np.max(p_list_total)
                if p_list_total_min != p_list_total_max:
                    #p_list_total = self.normalize_range(p_list_total, p_list_total_min, p_list_total_max, 0.1, 0.75)
                    p_list_total = self.normalize_range(p_list_total, 0, 1, 0.1, 0.75) # if min != max, rescale from [0, 1] into [0.1, 0.75]
                else:
                    p_list_total = [0.95 for p in p_list_total]

The remaining branches follow the same pattern; only the target range and the fallback balance value differ.

            # write back to p_observation (with -1)
            count = 0
            for fid in p_observation:
                for i,p in enumerate(p_observation[fid][test_id]):
                    if p >= 0:
                        p_observation[fid][test_id][i] = p_list_total[count]
                        count += 1

The normalized values are written back into p_observation (keeping the -1 placeholders in place), which completes the normalization of all observation probabilities; the normalized dict is returned.

Range normalization: normalize_range

    #range1: original; range2: target
    def normalize_range(self, p_list, min1, max1, min2, max2):
        p_list = [min2 + (p - min1)*(max2 - min2)/(max1 - min1) for p in p_list]
        
        return p_list

Linearly rescales each value from the original range [min1, max1] into the target range [min2, max2]:

$$p' = min_2 + \frac{(p - min_1)\,(max_2 - min_2)}{max_1 - min_1}$$
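For example, normalize_range([0.0, 0.5, 1.0], 0, 1, 0.2, 0.8) returns [0.2, 0.5, 0.8].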

Updating invalid p values: update_invalid_p

    def update_invalid_p(self, p_observation):
        logging.debug("[++++] Update invalid p")
        for fid, p_lists in p_observation.items():
            # test_id: 0: m, 1: r, 2: s, 3: d, 4: v
            
            # TODO: only need to check ms. others could not be invalid
            for test_id in [0]:
                p_balance = MyFactorGraph.compute_fg_threshold(ProbabilisticInference.P_K2M, ProbabilisticInference.P_M2K)
                for i, p in enumerate(p_lists[test_id]):
                    if p < 0: 
                        p_lists[test_id][i] = p_balance if p < -1.5 else 0.4 #0.7272

p_lists[test_id] is the list of one observation probability (here the message similarity constraint) over all clusters produced by splitting on a field. Values below -1.5 are replaced by the balance value from compute_fg_threshold; values in [-1.5, 0) are replaced by 0.4.

            # for r, remove -1 (the messages that have no request/response)
            # also update the num of p in p_implication
            for test_id in [1]:
                #i_filter = [i for i in range(len(p_lists[test_id])) if p_lists[test_id][i] > 0]
                i_filter = [i for i in range(len(p_lists[test_id])) if p_lists[test_id][i] >0 and self.pairs_size[fid][i] > 1]
                p_r = [p_lists[test_id][i] for i in i_filter]
                q_ktox_list_weighted_new = [self.p_implication[fid][0][test_id][i] for i in i_filter]
                q_xtok_list_weighted_new = [self.p_implication[fid][1][test_id][i] for i in i_filter]

                p_observation[fid][test_id] = p_r
                #self.p_implication[fid][0][test_id] = q_ktov_list_const_new
                #self.p_implication[fid][1][test_id] = q_vtok_list_const_new
                self.p_implication[fid][0][test_id] = q_ktox_list_weighted_new
                self.p_implication[fid][1][test_id] = q_xtok_list_weighted_new

For the remote coupling constraint, the -1 entries are removed (such clusters may have no request-response pairs); the corresponding entries in p_implication are updated at the same time.

            # TODO: no need. could not be invalid
            for test_id in [2, 3]:
                for i, p in enumerate(p_lists[test_id]):
                    if p < 0:
                        # TODO: compute the balance value automatically
                        p_lists[test_id][i] = 0.4 #0.7272

For the structure coherence and dimension constraints, p values below 0 are uniformly set to 0.4.

            # TODO
            for test_id in [4]:
                for i,p in enumerate(p_lists[test_id]):
                    if p < 0:
                        p_balance = MyFactorGraph.compute_fg_threshold(ProbabilisticInference.P_K2V, ProbabilisticInference.P_V2K)
                        p_lists[test_id][i] = p_balance - 0.45 #0.2
                    else:
                        p_lists[test_id][i] = 0.95 # TODO: remove it// 0.95

        return p_observation

For p_v, values below 0 are set to p_balance - 0.45 (roughly 0.28 given P_K2V = 0.9 and P_V2K = 0.6), and non-negative values are set to 0.95.

Getting the field inference result: get_fid_inferred

    # TODO: add algorithms to infer the fid from fg results
    def get_fid_inferred(self, fg_result, max_num=1, precision=0.01):
        result = dict()
        for fid in fg_result:
            result[fid] = fg_result[fid][0] # only use the first test
        result_sorted = sorted(result.items(), key=lambda x:x[1], reverse=True) # sort by probability, descending
        fid_inferred = [result_sorted[0][0]] # take the highest-probability field as the keyword first
        for i in range(1, len(result_sorted)):
            if result_sorted[i][1] - result_sorted[0][1] < precision:
                fid_inferred.append(result_sorted[i][0]) # fields within `precision` of the best may also be keywords, so include them
                # note: since the list is sorted descending, this difference is always <= 0, so every field passes;
                # the intended test was probably result_sorted[0][1] - result_sorted[i][1] < precision
        fid_inferred = [int(fid.split("-")[0]) for fid in fid_inferred[:max_num]]
        #print(fid_inferred)

        return fid_inferred

factor_graph.py

A wrapper around pgmpy that builds the factor graph and runs belief propagation, making the joint probability computation tractable.

Module imports

from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.inference import BeliefPropagation

FactorGraph: creates a factor graph; reference: the pgmpy 0.1.19 documentation on factor graphs.

DiscreteFactor: reference: the pgmpy 0.1.19 source documentation for pgmpy.factors.discrete.DiscreteFactor.

BeliefPropagation: reference: the pgmpy 0.1.19 documentation on belief propagation.

The main body; not analyzed line by line.

class MyFactorGraph:

    def __init__(self, p_observation, p_implication):
        self.p_observation = p_observation
        self.p_implication = p_implication

    # Compute Pk
    # type_list: 0: k2x & x2k, 1: k2x, 2: x2k, -1: not test
    def compute_pk(self, type_list, fid):
        assert len(type_list) == 5, print("ComputePk Error: number of type_list should be 5")

        constraint_name = ['m', 'r', 's', 'd', 'v']
        '''
        m, r, s, d, v = type_list
        p_m, p_r, p_s, p_d, p_v = self.p_observation
        p_ktox, p_xtok = self.p_implication
        p_ktom, p_ktor, p_ktos, p_ktod, p_ktov = p_ktox
        p_mtok, p_rtok, p_stok, p_dtok, p_vtok = p_xtok
        '''
        fg = FactorGraph()
        fg.add_node('k')

        for i in range(len(type_list)):
            if type_list[i] == 0:
                fg = self.add_constraints_k2x_x2k(fg, self.p_observation[fid][i], self.p_implication[fid][0][i], self.p_implication[fid][1][i], constraint_name[i])
            elif type_list[i] == 1:
                fg = self.add_constraints_k2x(fg, self.p_observation[fid][i], self.p_implication[fid][0][i], constraint_name[i])
            elif type_list[i] == 2:
                fg = self.add_constraints_x2k(fg, self.p_observation[fid][i], self.p_implication[fid][1][i], constraint_name[i])
        '''
        if m == 0:
            fg = add_constraints_kv_vk(fg, p_m, p_ktom, p_mtok, 'm')
        elif m == 1:
            fg = add_constraints_kv(fg, p_m, p_mtok, 'm')
        elif m == 2:
            fg = add_constraints_vk(fg, p_m, p_mtok, 'm')

        if r == 0:
            fg = add_constraints_kv_vk(fg, p_r, p_ktor, p_rtok, 'r')
        elif r == 1:
            fg = add_constraints_kv(fg, p_r, p_ktor, 'r')
        elif r == 2:
            fg = add_constraints_vk(fg, p_r, p_rtok, 'r')

        if s == 0:
            fg = add_constraints_kv_vk(fg, p_s, p_ktos, p_stok, 's')
        elif s == 1:
            fg = add_constraints_kv(fg, p_s, p_ktos, 's')
        elif s == 2:
            fg = add_constraints_vk(fg, p_s, p_stok, 's')

        if d == 0:
            fg = add_constraints_kv_vk(fg, p_d, p_ktod, p_dtok, 'd')
        elif d == 1:
            fg = add_constraints_kv(fg, p_d, p_ktod, 'd')
        elif d == 2:
            fg = add_constraints_vk(fg, p_d, p_dtok, 'd')

        if v == 0:
            fg = add_constraints_kv_vk(fg, p_v, p_ktov, p_vtok, 'v')
        elif v == 1:
            fg = add_constraints_kv(fg, p_v, p_ktov, 'v')
        elif v == 2:
            fg = add_constraints_vk(fg, p_v, p_vtok, 'v')
        '''

        bp = BeliefPropagation(fg)

        #result = bp.query(variables=['k'])['k']
        #result = bp.query(variables=['k'], joint=False)['k']
        result = bp.query(variables=['k'])
        result.normalize()
        #print(result)

        return result.values[1]

    # Add Constraints
    # k -> x
    def add_constraints_k2x(self, fg, p_x, p_ktox, x_name):
        for i in range(len(p_x)):
            p1 = p_x[i]
            p2 = p_ktox[i]
            x = '%s%d' % (x_name, i)
            fg.add_node(x)
            phi1 = DiscreteFactor([x], [2], [1 - p1, p1])
            phi2 = DiscreteFactor(['k', x], [2, 2], [p2, p2, 1 - p2, p2])
            fg.add_factors(phi1, phi2)
            fg.add_edges_from([(x, phi1), (x, phi2), ('k', phi2)])
        return fg

    # x -> k
    def add_constraints_x2k(self, fg, p_x, p_xtok, x_name):
        for i in range(len(p_x)):
            p1 = p_x[i]
            p3 = p_xtok[i]
            x = '%s%d' % (x_name, i)
            fg.add_node(x)
            phi1 = DiscreteFactor([x], [2], [1 - p1, p1])
            phi3 = DiscreteFactor(['k', x], [2, 2], [p3, 1 - p3, p3, p3])
            fg.add_factors(phi1, phi3)
            fg.add_edges_from([(x, phi1), (x, phi3), ('k', phi3)])
        return fg

    # k -> x & x -> k
    def add_constraints_k2x_x2k(self, fg, p_x, p_ktox, p_xtok, x_name):
        for i in range(len(p_x)):
            p1 = p_x[i]
            p2 = p_ktox[i]
            p3 = p_xtok[i]
            x = '%s%d' % (x_name, i)
            fg.add_node(x)
            phi1 = DiscreteFactor([x], [2], [1 - p1, p1])
            phi2 = DiscreteFactor(['k', x], [2, 2], [p2, p2, 1 - p2, p2])
            phi3 = DiscreteFactor(['k', x], [2, 2], [p3, 1 - p3, p3, p3])
            fg.add_factors(phi1, phi2, phi3)
            fg.add_edges_from([(x, phi1), (x, phi2), ('k', phi2), (x, phi3), ('k', phi3)])
        return fg
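To see the "balance value" computed below in action, here is a small self-contained check (assuming pgmpy is installed; it mirrors add_constraints_k2x_x2k with a single m-constraint) that an observation probability equal to compute_fg_threshold(0.8, 0.6) yields P(k=1) ≈ 0.5:

    from pgmpy.models import FactorGraph
    from pgmpy.factors.discrete import DiscreteFactor
    from pgmpy.inference import BeliefPropagation

    p_kv, p_vk = 0.8, 0.6
    p1 = (2*p_kv*p_vk - p_vk) / (4*p_kv*p_vk - p_kv - p_vk)  # compute_fg_threshold

    fg = FactorGraph()
    fg.add_nodes_from(['k', 'm0'])
    phi1 = DiscreteFactor(['m0'], [2], [1 - p1, p1])
    phi2 = DiscreteFactor(['k', 'm0'], [2, 2], [p_kv, p_kv, 1 - p_kv, p_kv])  # k -> m
    phi3 = DiscreteFactor(['k', 'm0'], [2, 2], [p_vk, 1 - p_vk, p_vk, p_vk])  # m -> k
    fg.add_factors(phi1, phi2, phi3)
    fg.add_edges_from([('m0', phi1), ('m0', phi2), ('k', phi2), ('m0', phi3), ('k', phi3)])

    result = BeliefPropagation(fg).query(variables=['k'])
    result.normalize()
    print(result.values[1])  # ~0.5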

Computing the balance value of two variables: compute_fg_threshold

    # Compute the balance value for different p_kv/p_vk
    @staticmethod
    def compute_fg_threshold(p_kv, p_vk):
        p_t = (2 * p_kv * p_vk - p_vk) / (4 * p_kv * p_vk - p_kv - p_vk)
        return p_t

$$p_t = \frac{2xy - y}{4xy - x - y}, \qquad \text{where } x = p_{kv},\ y = p_{vk}$$
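Where this formula comes from (my own derivation, not spelled out in the repo): p_t is the observation probability $p_1$ at which a single constraint, wired up as in add_constraints_k2x_x2k with implication probabilities $p_2 = p_{kv}$ and $p_3 = p_{vk}$, leaves k perfectly undecided. Marginalizing the product of the three factors over x gives

$$P(k{=}0) \propto p_2\big[(1-p_1)\,p_3 + p_1(1-p_3)\big], \qquad P(k{=}1) \propto (1-p_1)(1-p_2)\,p_3 + p_1\,p_2\,p_3.$$

Setting the two sides equal and solving for $p_1$ with $x = p_2$, $y = p_3$ yields exactly $p_t = \frac{2xy - y}{4xy - x - y}$; for $x = 0.8$, $y = 0.6$ this gives $p_t \approx 0.692$.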


clustering.py

Later-stage evaluation against the true keywords; nothing important here, so it is not explained further.

# This file is part of NetPlier, a tool for binary protocol reverse engineering.
# Copyright (C) 2021 Yapeng Ye

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>

import numpy as np
from sklearn import metrics
import logging
import re   # needed by get_true_keyword for the ftp case
import struct

class Clustering:
    def __init__(self, fields, protocol_type):
        self.fields = fields
        self.protocol_type = protocol_type
        
    def evaluation(self, clustering_result_true, clustering_result_method):
        print("[++++++++] Evaluate Clustering results")
        results_list = list()
        labels_true_list, labels_method_list = list(), list()
        for test_id in [0, 1]:
            results_true = clustering_result_true[test_id]
            if len(results_true) == 0:
                logging.error("The groundtruth could not be empty when evaluating clustering results")
                return

            dict_kwtoi = dict()
            for i,kw in enumerate(sorted(set(results_true), key=results_true.index)):
                dict_kwtoi[kw] = i
            labels_true = [dict_kwtoi[kw] for kw in results_true]
            labels_true_list.append(labels_true)

            
            results_method = clustering_result_method[test_id]
            dict_kwtoi = dict()
            for i,kw in enumerate(sorted(set(results_method), key=results_method.index)):
                dict_kwtoi[kw] = i
            labels_method = [dict_kwtoi[kw] for kw in results_method]
            labels_method_list.append(labels_method)

            h = metrics.homogeneity_score(labels_true, labels_method)
            c = metrics.completeness_score(labels_true, labels_method)
            v = metrics.v_measure_score(labels_true, labels_method) 
            
            test_direction = "Request" if test_id == 0 else "Response"
            print("{}:\nHomogeneity score: {:.8}\nCompleteness score: {:.8}\nV-measure score: {:.8}".format(test_direction, h, c, v))
            results_list.append([h, c, v])
        # total
        labels_true_request, labels_true_response = labels_true_list
        labels_method_request, labels_method_response = labels_method_list
        labels_true_total = labels_true_request + [kw + np.max(labels_true_request) + 1 for kw in labels_true_response]
        labels_method_total = labels_method_request + [kw + np.max(labels_method_request) + 1 for kw in labels_method_response]
        h = metrics.homogeneity_score(labels_true_total, labels_method_total)
        c = metrics.completeness_score(labels_true_total, labels_method_total)
        v = metrics.v_measure_score(labels_true_total, labels_method_total)
        print("Total:\nHomogeneity score: {:.8}\nCompleteness score: {:.8}\nV-measure score: {:.8}".format(h, c, v))
        results_list.append([h, c, v])

    def cluster_by_kw_true(self, messages):
        print("[++++++++] Cluster by True Keyword")
        results = list()

        if not self.protocol_type:
            logging.error("The protocol_type (-t) is required for computing the true clustering")
            return results
        
        for message in messages:
            kw = self.get_true_keyword(message)
            results.append(kw)
        
        return results

    def get_true_keyword(self, message):
        if self.protocol_type == "dhcp":
            kw = message.data[242:243]
        elif self.protocol_type == "dnp3":
            kw = message.data[12:13]
        elif self.protocol_type == "ftp":
            kw = re.split(" |-|\r|\n", message.data.decode())[0]
        elif self.protocol_type == "icmp":
            kw = message.data[0:2]
        elif self.protocol_type == "modbus":
            kw = message.data[7:8]
        elif self.protocol_type == "ntp":
            kw = message.data[0] & 0x07
        elif self.protocol_type == "smb":
            kw = message.data[4+4]
        elif self.protocol_type == "smb2":
            kw = struct.unpack("<H", message.data[4+12:4+12+2])[0]
        elif self.protocol_type == "tftp":
            kw = message.data[0:2]
        elif self.protocol_type == "zeroaccess":
            kw = message.data[4:8]
        else:
            logging.error("The protocol_type is unknown")

        if type(kw).__name__ == "bytes":
            kw = str(kw.hex())
        
        return kw

    def cluster_by_kw_inferred(self, fid_inferred_list, messages):
        print("[++++++++] Cluster by Inferred Keyword")
        results = [list() for message in messages]
        for fid_inferred in fid_inferred_list:
            il, ir = 0, 0
            for i in range(fid_inferred):
                il += self.fields[i].domain.dataType.size[1] // 8
            ir = il + (self.fields[fid_inferred].domain.dataType.size[1] // 8)

            for j in range(len(messages)):
                results[j].append(messages[j].data[il:ir])
        results = [''.join(result) for result in results]

        return results
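As a quick sanity check of the three scores used in evaluation (toy labels, assumed for illustration): metrics.homogeneity_score([0, 0, 1, 1], [0, 0, 1, 2]) is 1.0, because every predicted cluster contains only one true class, while metrics.completeness_score([0, 0, 1, 1], [0, 0, 1, 2]) is below 1.0, because true class 1 is split across two predicted clusters; the V-measure is the harmonic mean of the two.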


problem

  • Taking dnp3 as an example, the final result infers f2 as the keyword field, which does not match the true keyword field f11 (f7 in the paper); at first glance the multiple sequence alignment appeared to be at fault.

After ruling things out, the field partitioning in the multiple sequence alignment is correct, and the non-conservative field partitioning does not affect the final result (see the code). The problem is most likely in the probabilistic inference step: field f2's p_v clearly does not meet the requirements of a true keyword field.

  • The structure coherence and similarity matrix computations differ from the paper.
  • Which fields of the packets is the reverse engineering actually applied to?

2023-04-08
Anyone interested in discussing this can reach me on QQ: 2330547481.
