如果想构建自己的数据集,应该继承 dgl.data.DGLDataset 类,并且实现下面的方法:
- __getitem__(self, i):得到数据集的第 i 个数据
- __len__(self):数据集的大小
- process(self):从硬盘加载和处理原始数据
这里使用一个小数据集 Zachary’s Karate Club network,包含:
- members.csv 文件:包含每个成员的属性
- interactions.csv 文件:包含两个成员的关系
import urllib.request
import pandas as pd
# Download the two tutorial CSV files into the current working directory.
for _remote_url, _local_path in (
    ('https://data.dgl.ai/tutorial/dataset/members.csv', './members.csv'),
    ('https://data.dgl.ai/tutorial/dataset/interactions.csv', './interactions.csv'),
):
    urllib.request.urlretrieve(_remote_url, _local_path)

# Load the member table and preview it. The bare ``head()`` expression only
# produces visible output in an interactive session such as a notebook.
members = pd.read_csv('./members.csv')
members.head()

# Same for the interaction table.
interactions = pd.read_csv('./interactions.csv')
interactions.head()
我们将成员视作节点,关系视作边,年龄视作节点的属性,加入的 club 作为节点的标签,边的权重作为边的属性:
import pandas as pd
import dgl
from dgl.data import DGLDataset
import torch
import os
class KarateClubDataset(DGLDataset):
    """Zachary's Karate Club network wrapped as a single-graph DGL dataset.

    Members become nodes and interactions become edges. The ``Age`` column
    is stored as the node feature ``feat``, the ``Club`` column (encoded as
    integer category codes) as the node label ``label``, and the ``Weight``
    column as the edge feature ``weight``. Boolean ``train_mask`` /
    ``val_mask`` / ``test_mask`` node fields split the nodes 60% / 20% /
    remainder, in file order.

    Args:
        data_dir: Directory containing ``members.csv`` and
            ``interactions.csv``. When omitted, ``./karate`` is tried first
            (the original hard-coded location) and then the current
            directory, which is where the download step of this file saves
            the CSV files.
    """

    # Locations searched, in order, when no explicit data_dir is given.
    _DEFAULT_DATA_DIRS = ('./karate', '.')
    _REQUIRED_FILES = ('members.csv', 'interactions.csv')

    def __init__(self, data_dir=None):
        # Must be assigned before super().__init__(): the graph is usable
        # right after construction, i.e. the base-class constructor is what
        # ends up invoking process(), which needs this attribute.
        self._csv_dir = data_dir
        super().__init__(name='karate_club')

    def _resolve_csv_dir(self):
        """Return the directory the CSV files should be read from."""
        explicit_dir = getattr(self, '_csv_dir', None)
        if explicit_dir is not None:
            return explicit_dir
        for candidate in self._DEFAULT_DATA_DIRS:
            if all(os.path.isfile(os.path.join(candidate, fname))
                   for fname in self._REQUIRED_FILES):
                return candidate
        # Nothing found: fall back to the first default so that read_csv
        # raises FileNotFoundError naming the expected location.
        return self._DEFAULT_DATA_DIRS[0]

    def process(self):
        """Read the CSV files and build ``self.graph`` with its features."""
        csv_dir = self._resolve_csv_dir()
        nodes_data = pd.read_csv(os.path.join(csv_dir, 'members.csv'))
        edges_data = pd.read_csv(os.path.join(csv_dir, 'interactions.csv'))
        n_nodes = nodes_data.shape[0]

        node_features = torch.from_numpy(nodes_data['Age'].to_numpy())
        # Convert the Club column to a categorical and use its integer
        # category codes as class labels.
        node_labels = torch.from_numpy(
            nodes_data['Club'].astype('category').cat.codes.to_numpy())
        edge_features = torch.from_numpy(edges_data['Weight'].to_numpy())
        edges_src = torch.from_numpy(edges_data['Src'].to_numpy())
        edges_dst = torch.from_numpy(edges_data['Dst'].to_numpy())

        # num_nodes is passed explicitly so the node count always matches
        # the number of rows in the member table.
        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=n_nodes)
        self.graph.ndata['feat'] = node_features
        self.graph.ndata['label'] = node_labels
        self.graph.edata['weight'] = edge_features

        # If your dataset is a node classification dataset, you will need to
        # assign masks indicating whether a node belongs to training,
        # validation, and test set.
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train:n_train + n_val] = True
        test_mask[n_train + n_val:] = True
        self.graph.ndata['train_mask'] = train_mask
        self.graph.ndata['val_mask'] = val_mask
        self.graph.ndata['test_mask'] = test_mask

    def __getitem__(self, i):
        """Return the dataset's only graph.

        Raises:
            IndexError: If ``i`` is outside the valid range for a sequence
                of length ``len(self)``. Without this, any index would
                succeed and index-based iteration would never stop.
        """
        size = len(self)
        if not -size <= i < size:
            raise IndexError(
                f'index {i} is out of range for a dataset of size {size}')
        return self.graph

    def __len__(self):
        """Return the number of graphs in the dataset (always 1)."""
        return 1
# Build the dataset and fetch its graph; the dataset reports a length of 1,
# so index 0 is its only entry.
dataset = KarateClubDataset()
graph = dataset[0]
# Print the graph object.
print(graph)