只需一个float就可以,这样从文件里读取string类型的数据的时候,直接float即可转换。
if __name__ == '__main__': a = "0.2" print(float(a) / 10)
输出:
0.02
集合set的操作,对于动态变化的训练集操作
有时候训练集一直在变化,比如标签纠正,有些训练样本要增加或者删除,这时候用set集合操作去快速获取非训练样本。
import torch
if __name__ == '__main__':
# path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset')
# dataset = Amazon(path, 'Computers')
# print(dataset[0])
all_index = torch.arange(7)
train_index = torch.tensor([0,2,4])
all_list = all_index.cpu().numpy().tolist()
train_list = train_index.cpu().numpy().tolist()
rest_set = set(all_list) - set(train_list)
rest_tensor = torch.tensor(list(rest_set),dtype=torch.int64)
print(all_index,rest_tensor)
结果如下所示:
tensor([0, 1, 2, 3, 4, 5, 6]) tensor([1, 3, 5, 6])