检验torch_geometric各种图结构能否正常使用GPU

笔者的环境:

pytorch: 1.8.0+cu111
torch_geometric : 2.0.2
other: torch_scatter-2.0.6; torch_sparse-0.6.10; torch_spline_conv-1.2.1; torch_cluster-1.5.9

torch_geometric集成了各种各样图结构,但是因为不同的图结构会依赖于不同的后端计算(e.g., torch_cluster,torch_scatter),使得就算成功安装torch_geometric之后,有些图结构可以使用GPU进行推理,但是有些则不能。(比方说这种issue: github-issue)

用下面这段代码模拟三种图结构的推理训练反传:

import torch
from torch import nn
from torch_geometric.nn.conv.rgcn_conv import RGCNConv
from torch_geometric.nn.conv.gat_conv import GATConv
from torch_geometric.nn.conv.graph_conv import GraphConv

batch = 10

conv_1 = RGCNConv(10,10,2,num_bases=2)
conv_2 = GATConv(10,10,heads=2,concat=False,dropout=0.0,add_self_loops=True)
conv_3 = GraphConv(10,10)
cls = nn.Linear(10,20)

conv_1.to("cuda:0")
conv_2.to("cuda:0")
conv_3.to("cuda:0")
cls.to("cuda:0")

for i in range(batch):
	input = torch.randn((2,10)).requires_grad_().to("cuda:0")
	index = torch.tensor([[0,1],[1,0]]).long().to("cuda:0")
	edge_type = torch.tensor([0,1]).long().to("cuda:0")
	
	out_1 = conv_1(input,index,edge_type)
	out_2 = conv_2(out_1,index)
	out_3 = conv_3(out_2,index)
	out = cls(out_3)
	loss = torch.sum(out)
	
	print(torch.autograd.grad(loss,input,retain_graph=True))
	loss.backward()

输出结果如下:
在这里插入图片描述
在这里插入图片描述
如果上述代码运行正常没有报错,且输出正常,说明当前环境可以使用绝大部分torch_geometric的图实现。


狠一点的话,还可以把batch_size开很大,然后去看看显卡显存占用 (直接把下面这段代码粘到terminal里面):

import torch
from torch import nn
from torch_geometric.nn.conv.rgcn_conv import RGCNConv
from torch_geometric.nn.conv.gat_conv import GATConv
from torch_geometric.nn.conv.graph_conv import GraphConv

batch = 100000

conv_1 = RGCNConv(10,10,2,num_bases=2)
conv_2 = GATConv(10,10,heads=2,concat=False,dropout=0.0,add_self_loops=True)
conv_3 = GraphConv(10,10)
cls = nn.Linear(10,20)

conv_1.to("cuda:0")
conv_2.to("cuda:0")
conv_3.to("cuda:0")
cls.to("cuda:0")

for i in range(batch):
 input = torch.randn((2,10)).requires_grad_().to("cuda:0")
 index = torch.tensor([[0,1],[1,0]]).long().to("cuda:0")
 edge_type = torch.tensor([0,1]).long().to("cuda:0")
 out_1 = conv_1(input,index,edge_type)
 out_2 = conv_2(out_1,index)
 out_3 = conv_3(out_2,index)
 out = cls(out_3)
 loss = torch.sum(out)
 print(torch.autograd.grad(loss,input,retain_graph=True))
 loss.backward()

说明没问题:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值