一、使用InMemoryDataset数据集类
import os. path as osp
import torch
from torch_geometric. data import ( InMemoryDataset, download_url)
from torch_geometric. io import read_planetoid_data
class PlanetoidPubMed(InMemoryDataset):
    r"""PubMed citation graph built from the Planetoid benchmark files.

    Nodes represent articles and edges represent citation links.  The
    training, validation and test splits are given as binary masks.

    Args:
        root (string): Folder in which the dataset is stored.
        transform (callable, optional): Data transform that is applied every
            time a data object is accessed.
        pre_transform (callable, optional): Data transform that is applied
            once, before the data is saved to disk.
    """

    # Use the "raw" endpoint so that the file contents are downloaded; the
    # "/tree/" path is the repository's web view and does not serve the data
    # files themselves.
    url = 'https://gitee.com/rongqinchen/planetoid/raw/master/data'

    def __init__(self, root, transform=None, pre_transform=None):
        # The base-class constructor is expected to run download()/process()
        # when the raw/processed files are missing (see the PyG dataset docs).
        super().__init__(root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which may reject this pickled object --
        # confirm against the installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Folder that holds the downloaded, unprocessed files."""
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        """Folder that holds the collated, ready-to-load dataset file."""
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        """File names that must be present in ``raw_dir``."""
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        """File name that must be present in ``processed_dir``."""
        return 'data.pt'

    def download(self):
        """Download every raw file from ``url`` into ``raw_dir``."""
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        """Read the raw files, apply ``pre_transform`` and save the result."""
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        # The dataset consists of a single graph, collated as a 1-item list.
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        # This class never assigns ``self.name``; use the class name so that
        # repr()/print() on the dataset does not raise AttributeError.
        return '{}()'.format(self.__class__.__name__)
# Keep PubMed in a root folder of its own.  The class stores its output as
# <root>/processed/data.pt, so a folder named after another dataset (the
# original used 'dataset/Cora') is misleading and risks picking up a
# processed file that something else cached there.
dataset = PlanetoidPubMed('dataset/PlanetoidPubMed')
data = dataset[0]  # the dataset holds exactly one graph

# Basic statistics of the graph; reuse ``data`` instead of re-indexing.
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)
二、节点预测任务
定义一个GAT图神经网络
from torch_geometric. nn import GATConv, Sequential
from torch. nn import Linear, ReLU
import torch. nn. functional as F
class GAT(torch.nn.Module):
    """Stack of GATConv + ReLU layers followed by a linear classifier.

    Args:
        num_features: Size of each input node feature vector.
        hidden_channels_list: Output size of each GATConv layer, in order.
            Must be a non-empty list, since its last entry sizes the
            classifier input.
        num_classes: Number of output classes (size of the final logits).
    """

    def __init__(self, num_features, hidden_channels_list, num_classes):
        super().__init__()
        # Fixed seed so that parameter initialisation is reproducible.
        torch.manual_seed(12345)

        dims = [num_features] + hidden_channels_list
        layers = []
        # Walk over consecutive (input size, output size) pairs.
        for in_dim, out_dim in zip(dims, dims[1:]):
            layers.append((GATConv(in_dim, out_dim), 'x, edge_index -> x'))
            layers.append(ReLU(inplace=True))

        self.convseq = Sequential('x, edge_index', layers)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        """Return per-node class logits for the given graph."""
        hidden = self.convseq(x, edge_index)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.linear(hidden)
实例化模型并设置参数
# Two hidden GAT layers (200 and 100 channels) sized from the dataset.
model = GAT(
    num_features=dataset.num_features,
    hidden_channels_list=[200, 100],
    num_classes=dataset.num_classes,
)
print(model)

# ``5e-4`` must be written without spaces: ``5e - 4`` is not a valid float
# literal and fails to parse.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
进行训练
def train():
    """Run one full-graph optimisation step and return the training loss.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``data`` objects.  The returned value is the loss tensor computed on
    the nodes selected by ``data.train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x, data.edge_index)
    train_nodes = data.train_mask
    # Only the training nodes contribute to the loss.
    loss = criterion(logits[train_nodes], data.y[train_nodes])

    loss.backward()
    optimizer.step()
    return loss
# Train for 200 epochs, reporting the training loss after every step.
num_epochs = 200
for epoch in range(1, num_epochs + 1):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
测试结果
def test():
    """Return the model's classification accuracy on the test nodes.

    Uses the module-level ``model`` and ``data`` objects.  The result is a
    float: correctly predicted test nodes divided by the number of nodes
    selected by ``data.test_mask``.
    """
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a computation graph for the forward pass.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # predicted class = index of the largest logit
    test_correct = pred[data.test_mask] == data.y[data.test_mask]
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
    return test_acc
# Evaluate once after training and report the accuracy on the test split.
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
三、边预测任务实践
import torch
from torch_geometric. nn import GCNConv
class Net(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product decoder for edge prediction.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each node embedding produced by ``encode``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings computed by the two GCN layers."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of their endpoints.

        The two edge sets are concatenated along the last dimension, so the
        returned scores list the positive edges first, then the negatives.
        Assumes each edge index has the endpoints in rows 0 and 1.
        """
        candidates = torch.cat((pos_edge_index, neg_edge_index), dim=-1)
        src = candidates[0]
        dst = candidates[1]
        return torch.sum(z[src] * z[dst], dim=-1)

    def decode_all(self, z):
        """Return index pairs of all node pairs with a positive score.

        The score of a pair is the dot product of the two embeddings; the
        result has one column per selected pair.
        """
        scores = torch.matmul(z, z.t())
        return torch.nonzero(scores > 0, as_tuple=False).t()
原文地址