#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2019-08-27 11:14:53
# @Author : Mengji Zhang (zmj_xy@sjtu.edu.cn)
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from torch.nn.modules.module import Module

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

class RGCLayer(Module):
    # example params: (8285, 16, 66, 40, True, 0)
    def __init__(self, input_dim, h_dim, num_rel, num_base, featureless, drop_prob):
        super(RGCLayer, self).__init__()
        self.num_base = num_base
        self.input_dim = input_dim
        self.num_rel = num_rel
        self.h_dim = h_dim
        self.featureless = featureless
        self.drop_prob = drop_prob
        # Clamp num_base to the valid range (0, num_rel]: out-of-range values disable
        # decomposition. Keep self.num_base in sync so the checks in
        # reset_parameters() and forward() agree with the branch taken here.
        if num_base < 0 or num_base > num_rel:
            self.num_base = num_base = num_rel
        if num_base > 0 and num_base < self.num_rel:  # decompose only when num_base < num_rel
            # stacked basis matrices V_b, shared by all relations (see the check after the class)
            self.W = Parameter(torch.empty(input_dim * self.num_base, h_dim, dtype=torch.float32, device=device))  # (8285*40, 16)
            # per-relation coefficients a_{rb}, so that W_r = sum_b a_{rb} V_b
            self.W_comp = Parameter(torch.empty(num_rel, num_base, dtype=torch.float32, device=device))  # (66, 40)
        else:
            self.W = Parameter(torch.empty(input_dim * self.num_rel, h_dim, dtype=torch.float32, device=device))  # (8285*66, 16)
        self.B = Parameter(torch.zeros(h_dim, device=device))  # bias, (16,)
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W)
        if self.num_base > 0 and self.num_base < self.num_rel:
            nn.init.xavier_uniform_(self.W_comp)
        self.B.data.fill_(0)
    def forward(self, vertex, A):  # vertex is None at the first layer (featureless input)
        supports = []  # per-relation messages, used as features
        nodes_num = A[0].shape[0]  # 8285
        for i, adj in enumerate(A):  # one adjacency matrix per relation, i in [0, num_rel)
            if not self.featureless:  # featureless=False: real node features are available
                supports.append(torch.spmm(adj, vertex))
            else:
                # featureless (the default here): use each node's adjacency row under
                # relation r as its feature
                supports.append(adj)
        supports = torch.cat(supports, dim=1)  # concat side by side: (8285, 8285 * 66)
        # basis-decomposition regularization
        if self.num_base > 0 and self.num_base < self.num_rel:
            # Reconstruct the per-relation weight matrices W_r = sum_b a_{rb} V_b.
            # W (reshaped) holds the num_base shared basis matrices V_b; W_comp holds
            # the coefficients a_{rb}, i.e. each of the 66 relations' weight matrices
            # is a linear combination of the same 40 bases.
            # matmul broadcasts over the leading dim:
            # (66, 40) @ (8285, 40, 16) -> (8285, 66, 16) = (input_dim, R, h_dim)
            V = torch.matmul(self.W_comp, torch.reshape(self.W, (self.num_base, self.input_dim, self.h_dim)).permute(1, 0, 2))
            V = torch.reshape(V, (self.input_dim * self.num_rel, self.h_dim))  # (8285, 66, 16) -> (8285 * 66, 16)
            # apply the reconstructed weights: (8285, 8285 * 66) @ (8285 * 66, 16) -> (8285, 16)
            output = torch.spmm(supports, V)
        else:
            # no decomposition, plain W*x: (8285, 8285 * 66) @ (8285 * 66, 16) -> (8285, 16)
            output = torch.spmm(supports, self.W)
        if self.featureless:
            # Featureless input is implicitly the identity matrix, so feature dropout
            # reduces to dropping whole rows: build an (8285,) dropout mask from ones.
            temp = torch.ones(nodes_num).to(device)
            temp_drop = F.dropout(temp, self.drop_prob, training=self.training)  # only drop at train time
            # scale each row: (16, 8285) * (8285,) -> (16, 8285) -> transpose back to (8285, 16)
            output = (output.transpose(1, 0) * temp_drop).transpose(1, 0)
        output += self.B
        return output
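
# --- Usage sketch (not part of the original file; all values below are made up) ---
# A minimal smoke test of RGCLayer on tiny dimensions instead of the
# (8285, 16, 66, 40) setup annotated above, assuming sparse COO adjacencies.
if __name__ == "__main__":
    num_nodes, h_dim, num_rel, num_base = 5, 4, 3, 2
    # one sparse adjacency matrix per relation; identity matrices as stand-ins
    A = [torch.eye(num_nodes, device=device).to_sparse() for _ in range(num_rel)]
    # featureless=True means input_dim must equal the number of nodes
    layer = RGCLayer(num_nodes, h_dim, num_rel, num_base, featureless=True, drop_prob=0.0)
    out = layer(None, A)  # vertex=None because featureless=True
    print(out.shape)  # expected: torch.Size([5, 4])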
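
# --- Basis-decomposition check (illustrative addition, not in the original file) ---
# Verifies with small random tensors that the broadcasted matmul used in forward()
# really computes W_r = sum_b a_{rb} V_b for every relation r.
def _check_basis_decomposition(input_dim=5, h_dim=4, num_rel=3, num_base=2):
    W = torch.randn(input_dim * num_base, h_dim)  # stacked bases, as in RGCLayer.W
    W_comp = torch.randn(num_rel, num_base)       # coefficients, as in RGCLayer.W_comp
    bases = torch.reshape(W, (num_base, input_dim, h_dim))
    V = torch.matmul(W_comp, bases.permute(1, 0, 2))  # (input_dim, num_rel, h_dim)
    for r in range(num_rel):
        W_r = sum(W_comp[r, b] * bases[b] for b in range(num_base))  # explicit sum
        assert torch.allclose(V[:, r, :], W_r, atol=1e-5)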