Segmented GRAPH-BERT代码理解①script_preprocess.py

该博客详细介绍了如何处理图数据,包括加载原始图、使用StratifiedKFold进行交叉验证,以及进行Weisfeiler-Lehman图核计算。内容涵盖了图的节点填充、元特征提取和标签编码,为后续的图分类任务做了充分的预处理工作。
摘要由CSDN通过智能技术生成
dataset_name = 'MUTAG'
strategy = 'isolated_segment'

对应参数设置为:

max_graph_size = 40

1加载原始图

data_obj.load_type = 'Raw'

对应加载作者整理好的data.txt文件为数据集

通过MethodProcessRaw里的load_raw_graph_list生成graph_list, graph_size_list

g_list = [] 
label_dict = {} 
feat_dict = {}
graph_size_list = [] 
with open(file_path, 'r') as f:
    n_g = int(f.readline().strip())
    # 188
    for i in range(n_g):
        row = f.readline().strip().split()
        graph_size_list.append(int(row[0]))
        n, l = [int(w) for w in row]
        if not l in label_dict:
            mapped = len(label_dict)
            label_dict[l] = mapped
        g = nx.Graph()
        n_edges = 0
        for j in range(n):
            g.add_node(j)
            row = f.readline().strip().split()
            row = [int(w) for w in row]
            n_edges += row[1]
            for k in range(2, len(row)):
                g.add_edge(j, row[k])
		# 通过这一部分走过下面不止两个元素的row,从而可以正常给n, l 赋值
        assert len(g) == n
        g_list.append({'graph': g, 'label': l})

这时label_dict为{2: 0, 0: 1}
key是图的标签,value是索引。也即有类别标签的图分类问题
生成图g以最后一个为例用plt可视化
在这里插入图片描述
返回得graph_list
在这里插入图片描述
graph_size_list(实为列向量,下图为其转置)
在这里插入图片描述

通过MethodProcessRaw里的separate_data生成train_idx_dic, test_idx_dict
这里使用StratifiedKFold函数做k折交叉验证,采用分层划分的方法(分层随机抽样思想),验证集中不同类别占比与原始样本的比例保持一致,故StratifiedKFold在做划分的时候需要传入标签特征(即graph_list中’label’里的2和0)。

train_idx_dict = {} 
test_idex_dict = {}
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
labels = [graph['label'] for graph in graph_list] 
fold_count = 1 
for train_idx, test_idx in skf.split(np.zeros(len(labels)), labels):
    train_idx_dict[fold_count] = train_idx
    test_idex_dict[fold_count] = test_idx
    fold_count += 1

其中train_idx, test_idx如下图
在这里插入图片描述
最终返回得到train_idx_dic在这里插入图片描述
test_idx_dict在这里插入图片描述

最后同时返回graph_size_list中的最大值28

2 WL

data_obj.load_type = 'Processed'

对应加载上面生成的字典
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
从而计算WL
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3图填充&元特征/标签提取

加载任务一生成的大字典作为数据集

传入graph_dict
在这里插入图片描述
和max_graph_size = 40

def padding(self, graph_dict, max_size):
        node_tags = [max_size+1]*max_size
        node_degrees = [0] * max_size
        wl_tags = [0]*max_size
        #这里用最大尺寸为标准扩充所有图,以0占位
        w_list = []

        graph = graph_dict['graph']
        if graph_dict['label'] not in self.label_dict:
            self.label_dict[graph_dict['label']] = len(self.label_dict)
        y = self.label_dict[graph_dict['label']]
        wl_code = graph_dict['node_WL_code']

        node_list = list(graph.nodes)
        idx_map = {j: i for i, j in enumerate(node_list)}
        for i in range(max_size):
            w = [0.0] * max_size
            if i < len(node_list):
                node = node_list[i]
                node_tags[i] = node
                node_degrees[i] = graph.degree(node)
                wl_tags[i] = wl_code[i]
                neighbor_list = list(graph.neighbors(node))
                for neighbor in neighbor_list:
                    if idx_map[neighbor] >= max_size: continue
                    w[idx_map[neighbor]] = 1.0
            w_list.append(w)

        return node_tags, node_degrees, wl_tags, w_list, y

返回以最后一组的结果的转置为例
node_tags
在这里插入图片描述
node_degrees
在这里插入图片描述
wl_tags
在这里插入图片描述
w_list
在这里插入图片描述
y=0

for i in range(len(self.data['graph_list'])):
      graph = self.data['graph_list'][i]
      tag, degree, wl, w, y = self.padding(graph, max_graph_size)
      processed_graph_data.append({'id': i, 'tag': tag, 'degree': degree, 'weight': w, 'wl_tag': wl, 'y': y})
  self.data['processed_graph_data'] = processed_graph_data

最终返回结果是在原数据字典中加入’processed_graph_data’
转成numpy格式转置如下
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值