目标
针对DGL数据集做few-shot问题时候,需要将数据集划分成nshot的train,val,test。要求划分出多个task,每个task的train,val,test比例一致。同一个task内部,train,val,test无交叉;task间,train,val,test可以交叉。
Graph-Level
数据格式
以MUTAG为例:
{
'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
ndata_schemes={
'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={
}), 'label': tensor(0)}
label代表graph的类别
代码
#按照label做升序
def cmp(a,b):
if a['label']<b['label']:
return -1
if a['label']>b['label']:
return 1
return 0
def few_shot_split_graphlevl(dataset,train_shotnum,val_shotnum,classnum,tasknum):
#task中的train,val,test无交叉;但不同task之间的train,val,test可以交叉
#train_shotnum代表train中的shotnum,val_shotnum同理
#classnum代表总共有几类
#先将dataset中的数据按照class排序,然后随机从中选出数据来放入train val,剩下的放进test即可
train=[]
val=[]
test=[]
dataset=sorted(dataset,key=functools.cmp_to_key(cmp))
length=len(dataset)
#统计每类各有多少张图
classcount=torch.zeros(classnum)
#统计每类的第一个元素在dataset中的索引位置
class_start_index=torch.zeros(classnum)
label_before=1e6
count=0
for data in dataset:
classcount[data['label']]+=1
if label_before != data['label']:
label_before=data['label']
class_start_index[data['label']]=count
count+=1
#print('classcount:',classcount)
#print(class_start_index)
class_start_index=class_start_index.int()
for task in range(tasknum):
train_index=