针对DGL的few-shot数据集划分方法

目标

针对DGL数据集做few-shot问题时候,需要将数据集划分成nshot的train,val,test。要求划分出多个task,每个task的train,val,test比例一致。同一个task内部,train,val,test无交叉;task间,train,val,test可以交叉。

Graph-Level

数据格式

以MUTAG为例:

{
   'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
      ndata_schemes={
   'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
      edata_schemes={
   }), 'label': tensor(0)}

label代表graph的类别

代码

#按照label做升序
def cmp(a,b):
    if a['label']<b['label']:
        return -1
    if a['label']>b['label']:
        return 1
    return 0

def few_shot_split_graphlevl(dataset,train_shotnum,val_shotnum,classnum,tasknum):
    #task中的train,val,test无交叉;但不同task之间的train,val,test可以交叉
    #train_shotnum代表train中的shotnum,val_shotnum同理
    #classnum代表总共有几类
    #先将dataset中的数据按照class排序,然后随机从中选出数据来放入train val,剩下的放进test即可
    train=[]
    val=[]
    test=[]
    dataset=sorted(dataset,key=functools.cmp_to_key(cmp))
    length=len(dataset)
    #统计每类各有多少张图
    classcount=torch.zeros(classnum)
    #统计每类的第一个元素在dataset中的索引位置
    class_start_index=torch.zeros(classnum)
    label_before=1e6
    count=0
    for data in dataset:
        classcount[data['label']]+=1
        if label_before != data['label']:
            label_before=data['label']
            class_start_index[data['label']]=count
        count+=1
    #print('classcount:',classcount)
    #print(class_start_index)
    class_start_index=class_start_index.int()
    for task in range(tasknum):
        train_index=
要使用DGL创建自己的数据集来用于图分类,可以按照以下步骤操作: 1.准备数据:将图形数据存储为图形文件或使用Python脚本生成图形数据。确保每个节点都有唯一的ID,并且图形数据以节点和边列表的形式存储。 2.使用DGL创建Graph对象:使用DGL创建一个空图形对象,并使用节点和边列表填充它。 3.添加标签:为每个节点添加标签,这将成为我们的目标变量。标签可以是任何类型的标记,例如整数或字符串。 4.划分数据集:将数据集划分为训练集、验证集和测试集。 5.使用DGLDataset创建自定义数据集:使用DGL提供的DGLDataset类创建自定义数据集。在这个类中,你需要实现__init__、__getitem__和__len__方法。__init__方法用于加载数据,__getitem__方法用于返回单个数据样本,__len__方法用于返回数据集的大小。 6.创建数据加载器:使用DGL提供的Dataloader类创建数据加载器。 7.训练和测试:使用创建的数据加载器进行训练和测试。 以下是一个简单的示例,演示如何使用DGL创建自己的数据集: ```python import dgl from dgl.data import DGLDataset from dgl.dataloading import GraphDataLoader class MyDataset(DGLDataset): def __init__(self): super().__init__(name='mydataset') # Load data and labels # data is a list of tuples (src, dst) # labels is a list of integers self.data, self.labels = load_data_and_labels() # Create a DGL graph object self.graph = dgl.graph((self.data[:, 0], self.data[:, 1])) # Add labels to nodes self.graph.ndata['label'] = self.labels # Split dataset into train, validation, and test sets self.train_idx, self.valid_idx, self.test_idx = split_dataset() def __getitem__(self, idx): return self.graph, self.graph.ndata['label'][idx] def __len__(self): return len(self.graph) # Create a data loader dataset = MyDataset() train_loader = GraphDataLoader(dataset, batch_size=32, shuffle=True) # Train and test the model for epoch in range(num_epochs): for batched_graph, labels in train_loader: # Train the model pass # Test the model for batched_graph, labels in test_loader: # Evaluate the model pass ``` 在这个示例中,我们首先使用load_data_and_labels函数加载数据和标签,然后使用dgl.graph函数创建一个DGL图对象。我们将标签作为节点数据添加到图形中,并使用split_dataset函数将数据集划分为训练、验证和测试集。 接下来,我们使用MyDataset类创建自定义数据集,并使用GraphDataLoader类创建数据加载器。在训练和测试循环中,我们使用数据加载器加载数据,并用它们训练和测试模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值