DGL大图采样报错记录:Expect argument “nodes[“_N“]“ to have data type torch.int32. But got torch.int64.

作者在使用DGL处理大规模图数据进行采样训练时,遇到一个错误,即DGL期望节点数据类型为torch.int32,但实际输入为torch.int64。经过多日排查,错误地将所有数据强制转换为torch.int32仍然报错。最终发现,将数据类型改为torch.int64后,程序成功运行,表明之前的问题可能是由于其他原因导致的误解。
摘要由CSDN通过智能技术生成

我被这个问题折腾了好几天,因为在处理一个很庞大的图数据,需要采样训练,然后数据从原始数据处理成图数据,导入DGL采样模块中,结果就一直报错如下。

    nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 91, in prepare_tensor_dict
    return {
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 92, in <dictcomp>
    key: prepare_tensor(g, val, '{}["{}"]'.format(name, key))
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 36, in prepare_tensor
    raise DGLError(
dgl._ffi.base.DGLError: Expect argument "nodes["_N"]" to have data type torch.int32. But got torch.int64.

报错意思是:dgl期望输入的图节点数据类型为torch.int32,但是实际输入的数据类型为torch.int64。

也就是dgl需要torch.int32,但是我不小心把输入的数据类型弄成了torch.int64,我一直以为是不是我数据类似问题,结果就一直从头到尾把所有数据类型都强制转换成torch.int32,结果还是报错如上,折腾了真的好几天,一直不明白为啥!!!!

结果就在刚刚,我不小心把数据类型强制转换成torch.int64,结果结果,居然跑通了!!!!

什么鬼啊这是,这个BUG怎么来的,我吐了啊,折腾了我好几天啊。。。。。

下面贴一段代码,大家可以自己试试:

import torch
import numpy as np
import dgl

us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int32), torch.tensor(vs, dtype=torch.int32)))

sampler = dgl.dataloading.NeighborSampler([64, 32])

train_nids = np.arange(1000)

train_loader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=32,
    shuffle=True,
    drop_last=False,
    num_workers=4)

for input_nodes, output_nodes, blocks in train_loader:
    print(input_nodes)

上面会报错,你只需要修改成torch.int64即可,如下。

import torch
import numpy as np
import dgl

us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int64), torch.tensor(vs, dtype=torch.int64)))

sampler = dgl.dataloading.NeighborSampler([64, 32])

train_nids = np.arange(1000)

train_loader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=32,
    shuffle=True,
    drop_last=False,
    num_workers=4)

for input_nodes, output_nodes, blocks in train_loader:
    print(input_nodes)

运行结果
在这里插入图片描述

正常采样了。。。。。

我吐了,不带这么坑人的。。。。。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值