针对DGL的few-shot数据集划分方法

目标

针对DGL数据集做few-shot问题时候,需要将数据集划分成nshot的train,val,test。要求划分出多个task,每个task的train,val,test比例一致。同一个task内部,train,val,test无交叉;task间,train,val,test可以交叉。

Graph-Level

数据格式

以MUTAG为例:

{
   'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
      ndata_schemes={
   'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
      edata_schemes={
   }), 'label': tensor(0)}

label代表graph的类别

代码

#按照label做升序
def cmp(a,b):
    if a['label']<b['label']:
        return -1
    if a['label']>b['label']:
        return 1
    return 0

def few_shot_split_graphlevl(dataset,train_shotnum,val_shotnum,classnum,tasknum):
    #task中的train,val,test无交叉;但不同task之间的train,val,test可以交叉
    #train_shotnum代表train中的shotnum,val_shotnum同理
    #classnum代表总共有几类
    #先将dataset中的数据按照class排序,然后随机从中选出数据来放入train val,剩下的放进test即可
    train=[]
    val=[]
    test=[]
    dataset=sorted(dataset,key=functools.cmp_to_key(cmp))
    length=len(dataset)
    #统计每类各有多少张图
    classcount=torch.zeros(classnum)
    #统计每类的第一个元素在dataset中的索引位置
    class_start_index=torch.zeros(classnum)
    label_before=1e6
    count=0
    for data in dataset:
        classcount[data['label']]+=1
        if label_before != data['label']:
            label_before=data['label']
            class_start_index[data['label']]=count
        count+=1
    #print('classcount:',classcount)
    #print(class_start_index)
    class_start_index=class_start_index.int()
    for task in range(tasknum):
        train_index=
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值