HGAT: Heterogeneous Graph Attention Networks for Semi-supervised Short Text Classification - Paper Reading Notes and Code Reproduction

Paper: Heterogeneous Graph Attention Networks for Semi-supervised Short Text Classification. Linmei Hu, Tianchi Yang, Chuan Shi, Houye Ji, Xiaoli Li. Reading notes and code reproduction.

Do not repost without permission.

Code

main.py

# -*- coding: utf-8 -*-
import torch
import numpy as np
from network import HGraphConvolutionNet
import torch.nn as nn
import torch.optim as optim
import os
import gc
import matplotlib.pyplot as plt
import time

# Fix the random seeds so results are reproducible
np.random.seed(1)
torch.random.manual_seed(1)
torch.set_default_tensor_type(torch.FloatTensor)

# Whether to use the GPU
gpu = True

# Default data paths and model save path
fast_snippet_path = 'data/snippets.pt'
snippet_path = 'data/snippets.npz'
best_model_path = 'model/best_model.pt'

# Fall back to the CPU if no GPU is available for training
train_device = 'cuda' if (torch.cuda.is_available() and gpu) else 'cpu'
test_device = 'cpu'  # force the test phase to run on the CPU

# Training parameters
learning_rate = 0.005
weight_decay = 5e-6
max_epochs = 1000
num_per_class = 100  # labelled samples per class for train + val combined; split evenly, so 100 means 50 train and 50 val samples per class

# Network parameters
num_hiddens = [512]
num_layer = 2

def block_diag(dd, dt, de, tt, te, ee):
    """
    组合特征矩阵,输出 N*N矩阵

    """
    _num_document = len(dd)
    _num_topic = len(tt)
    _num_entity = len(ee)
    _part_1 = torch.cat([dd, dt, de], dim=1)
    _part_2 = torch.cat([dt.T, tt, te], dim=1)
    _part_3 = torch.cat([de.T, te.T, ee], dim=1)
    return torch.cat([_part_1, _part_2, _part_3])
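
# Block layout of the matrix returned by block_diag() above (d = document, t = topic, e = entity):
#     [ dd    dt    de ]
#     [ dt.T  tt    te ]
#     [ de.T  te.T  ee ]
# Rows/columns are therefore ordered: document nodes, then topic nodes, then entity nodes.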


def normalize_adj(adj, device):
    """
    计算归一化的邻接矩阵,具体内容参考论文中公式1

    """
    n_samples = adj.shape[0]
    adj_norm = adj + torch.eye(n_samples).to(device)
    m = torch.diag(torch.sqrt(torch.tensor(1).to(device) / torch.sum(adj_norm, dim=1)))
    adj_norm = torch.matmul(torch.matmul(m, adj_norm), m)
    return adj_norm
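
# normalize_adj() above implements the symmetric normalization of Eq. 1 in the paper:
#     A~ = A + I,   M = diag(1 / sqrt(rowsum(A~))),   A_hat = M A~ M
# i.e. the usual GCN renormalization trick, applied to the combined heterogeneous adjacency.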


def block_split(num_document, num_topic, num_entity):
    """
    辅助函数,生成各类型节点在邻接矩阵中的索引

    """
    x_1 = num_document
    x_2 = num_document + num_topic
    x_3 = num_document + num_topic + num_entity
    return [torch.arange(0, x_1), torch.arange(x_1, x_2), torch.arange(x_2, x_3)]
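
# Example: block_split(3, 2, 2) returns
#     [tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])]
# i.e. the row/column index ranges of the document, topic and entity blocks.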


def fast_load(path=fast_snippet_path):
    """
    载入Snippets数据(快速模式,数据占空间较大)。
    输出为主题特征、文档特征、实体特征、主题-主题关系、主题-文档关系、主题-实体关系、文档-文档关系、文档-实体关系、实体-实体关系和类别信息。

    """
    _data = torch.load(path)
    _all_document_x = _data['features'][0]
    _all_topic_x = _data['features'][1]
    _all_entity_x = _data['features'][2]
    _all_adj_document = _data['adj'][0][0]
    _all_adj_document_topic = _data['adj'][0][1]
    _all_adj_document_entity = _data['adj'][0][2]
    _all_adj_topic = _data['adj'][1][1]
    _all_adj_topic_entity = _data['adj'][1][2]
    _all_adj_entity = _data['adj'][2][2]
    _all_labels = torch.argmax(_data['labels'], dim=1)
    return _all_document_x, _all_topic_x, _all_entity_x, _all_adj_document, _all_adj_document_topic, \
           _all_adj_document_entity, _all_adj_topic, _all_adj_topic_entity, _all_adj_entity, _all_labels


def load(path=snippet_path):
    """
    载入Snippets数据(正常模式,数据占空间小,已压缩)
    输出为主题特征、文档特征、实体特征、主题-主题关系、主题-文档关系、主题-实体关系、文档-文档关系、文档-实体关系、实体-实体关系和类别信息。

    """
    _data = np.load(path)
    _all_document_x = torch.tensor(_data['d_x'])
    _all_topic_x = torch.tensor(_data['t_x'])
    _all_entity_x = torch.tensor(_data['e_x'])
    _all_adj_document = torch.tensor(_data['adj_dd'])
    _all_adj_document_topic = torch.tensor(_data['adj_dt'])
    _all_adj_document_entity = torch.tensor(_data['adj_de'])
    _all_adj_topic = torch.tensor(_data['adj_tt'])
    _all_adj_topic_entity = torch.tensor(_data['adj_te'])
    _all_adj_entity = torch.tensor(_data['adj_ee'])
    _all_labels = torch.tensor(_data['d_y'])
    return _all_document_x, _all_topic_x, _all_entity_x, _all_adj_document, _all_adj_document_topic, \
           _all_adj_document_entity, _all_adj_topic, _all_adj_topic_entity, _all_adj_entity, _all_labels

# Load the dataset
if os.path.exists(fast_snippet_path):
    all_document_x, all_topic_x, all_entity_x, all_adj_document, all_adj_document_topic, all_adj_document_entity, \
    all_adj_topic, all_adj_topic_entity, all_adj_entity, all_labels = fast_load()
else:
    all_document_x, all_topic_x, all_entity_x, all_adj_document, all_adj_document_topic, all_adj_document_entity, \
    all_adj_topic, all_adj_topic_entity, all_adj_entity, all_labels = load()

# Total number of classes
num_classes = len(torch.unique(all_labels))

np_labels = all_labels.detach().numpy()

# Split the data into training, validation and test sets
train_dict = []
for i in range(num_classes):
    class_indices = np.where(np_labels == i)[0]
    if len(class_indices) > num_per_class:
        # Sample without replacement so each class really contributes num_per_class documents
        train_dict.extend(np.random.choice(class_indices, num_per_class, replace=False))
    else:
        train_dict.extend(class_indices)
train_dict = set(train_dict)

num_documents = len(all_document_x)
indices = set(np.arange(num_documents))
indices.difference_update(train_dict)

train_dict = list(train_dict)
np.random.shuffle(train_dict)

indices = list(indices)
np.random.shuffle(indices)

train_dict.extend(indices)
indices = train_dict

num_train = num_per_class * num_classes // 2
num_val = num_per_class * num_classes // 2

train_indices = indices[0:num_train]
val_indices = indices[num_train:num_train + num_val]
test_indices = indices[num_train + num_val::]
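
# With num_per_class = 100, train and val each receive num_per_class * num_classes // 2 documents
# (400 each for the 8-class Snippets set); every remaining document goes to the test set.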

num_features = [all_document_x.shape[1], all_topic_x.shape[1], all_entity_x.shape[1]]

# Build the HGAT network
net = HGraphConvolutionNet(num_features, num_hiddens, num_classes, num_layer, bias=False, device=train_device,
                           allow_attention=True, dropout=0.80)

# L2 regularization (weight decay), configured per parameter group; the last group gets no weight decay
net_params = []
for idx, layer in enumerate(net.layers):
    if idx == len(net.layers) - 1:
        net_params.append({
            'params': layer.parameters(),
            'weight_decay': 0.0
        })
    else:
        net_params.append({
            'params': layer.parameters(),
            'weight_decay': weight_decay
        })
net_params.append({
    'params': net.attentions.parameters(),
    'weight_decay': weight_decay
})
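
# Each dict above is a separate parameter group, so Adam applies weight_decay per group;
# the group with weight_decay=0.0 is effectively exempt from L2 regularization.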

# Adam optimizer
optimizer = optim.Adam(net_params, lr=learning_rate)

# Loss function
loss_function = nn.CrossEntropyLoss()

# If a trained model already exists, load it; otherwise train from scratch
if os.path.exists(best_model_path):
    net.load_state_dict(torch.load(best_model_path, map_location=torch.device(train_device)))
else:
    # Gather the training and validation data
    train_y = all_labels[train_indices].to(train_device)
    val_y = all_labels[val_indices].to(train_device)

    train_document_x = all_document_x[train_indices]
    train_adj_document = all_adj_document[train_indices, :][:, train_indices]
    train_adj_document_topic = all_adj_document_topic[train_indices, :]
    train_adj_document_entity = all_adj_document_entity[train_indices, :]

    val_document_x = all_document_x[val_indices]
    val_adj_document = all_adj_document[val_indices, :][:, val_indices]
    val_adj_document_topic = all_adj_document_topic[val_indices, :]
    val_adj_document_entity = all_adj_document_entity[val_indices, :]

    train_adj = block_diag(train_adj_document, train_adj_document_topic, train_adj_document_entity, all_adj_topic,
                           all_adj_topic_entity, all_adj_entity).to(train_device)
    train_adj = normalize_adj(train_adj, train_device)
    val_adj = block_diag(val_adj_document, val_adj_document_topic, val_adj_document_entity, all_adj_topic,
                         all_adj_topic_entity, all_adj_entity).to(train_device)
    val_adj = normalize_adj(val_adj, train_device)

    train_features = [train_document_x.to(train_device), all_topic_x.to(train_device), all_entity_x.to(train_device)]
    val_features = [val_document_x.to(train_device), all_topic_x.to(train_device), all_entity_x.to(train_device)]
    train_split = block_split(len(train_document_x), len(all_topic_x), len(all_entity_x))
    val_split = block_split(len(val_document_x), len(all_topic_x), len(all_entity_x))

    # Initialize data for plotting
    plot_x_data = []
    plot_train_acc_data = []
    plot_train_loss_data = []
    plot_val_acc_data = []
    plot_val_loss_data = []

    # Train the network
    for epoch in range(max_epochs):
        net.train()
        optimizer.zero_grad()
        
        train_output = net(train_features, train_adj, train_split, normalize=False)[0:len(train_y)]
        train_loss = loss_function(train_output, train_y)
        train_pred_y = torch.argmax(train_output, dim=1)
        train_acc = torch.sum(train_y == train_pred_y).item() / len(train_y)

        # Backpropagate the training loss (no need to compute it a second time)
        train_loss.backward()
        optimizer.step()
        with torch.no_grad():
            net.eval()
            val_output = net(val_features, val_adj, val_split, normalize=False)[0:len(val_y)]
            val_loss = loss_function(val_output, val_y)
            val_pred_y = torch.argmax(val_output, dim=1)
            val_acc = torch.sum(val_y == val_pred_y).item() / len(val_y)
            print(
                f'Epoch={epoch}, train_loss={train_loss.item()}, train_acc={train_acc}, val_loss={val_loss.item()}, val_acc={val_acc}')

            plot_val_acc_data.append(val_acc)
            plot_train_acc_data.append(train_acc)
            plot_train_loss_data.append(train_loss.item())
            plot_val_loss_data.append(val_loss.item())
            plot_x_data.append(epoch)

            if (train_acc > 0.85 and val_acc > 0.85) or train_acc > 0.995:
                torch.save(net.state_dict(), best_model_path)
                print('Early stopping: saving network parameters.')
                break
    torch.save(net.state_dict(), best_model_path)

    plt.plot(plot_x_data, plot_train_acc_data, label='train')
    plt.plot(plot_x_data, plot_val_acc_data, label='val')
    plt.legend()
    plt.title('Training/Validation accuracy')
    plt.xlabel('Number of epochs')
    plt.ylabel('Accuracy')
    plt.savefig('figures/accuracy.pdf')
    plt.show()

    plt.plot(plot_x_data, plot_train_loss_data, label='train')
    plt.plot(plot_x_data, plot_val_loss_data, label='val')
    plt.legend()
    plt.xlabel('Number of epochs')
    plt.ylabel('Loss')
    plt.title('Training/Validation loss')
    plt.savefig('figures/loss.pdf')
    plt.show()

# Evaluate performance on the test set
gc.collect()
test_y = all_labels[test_indices].to(test_device)
test_document_x = all_document_x[test_indices]
test_adj_document = all_adj_document[test_indices, :][:, test_indices]
test_adj_document_topic = all_adj_document_topic[test_indices, :]
test_adj_document_entity = all_adj_document_entity[test_indices, :]
test_adj = block_diag(test_adj_document, test_adj_document_topic, test_adj_document_entity, all_adj_topic,
                      all_adj_topic_entity, all_adj_entity).to(test_device)
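
# Unlike the training/validation graphs, test_adj is not pre-normalized here; the forward pass
# below therefore uses the default normalize=True inside HGraphConvolutionNet.forward().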

test_features = [test_document_x.to(test_device), all_topic_x.to(test_device), all_entity_x.to(test_device)]
test_split = block_split(len(test_document_x), len(all_topic_x), len(all_entity_x))
net.to(test_device)
net.device = test_device
for layer_attentions in net.attentions:
    for attention in layer_attentions:
        attention.device = test_device
test_output = net(test_features, test_adj, test_split)[0:len(test_y)]
test_loss = loss_function(test_output, test_y)
test_pred_y = torch.argmax(test_output, dim=1)
test_acc = torch.sum(test_y == test_pred_y).item() / len(test_y)
print(f'Test: test_loss={test_loss}, test_acc={test_acc}')

network.py

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.distributions import uniform
import torch.nn.functional as F


class GraphConvolutionLayer(nn.Module):
    """
    GCN层

    """
    def __init__(self, in_features, out_features, bias=True, device='cpu'):
        super(GraphConvolutionLayer, self).__init__()
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        self.allow_bias = bias

        # Initialize the layer weights
        std = 1.0 / math.sqrt(out_features)
        sampler = uniform.Uniform(-std, std)
        # Move the sampled tensors to the target device *before* wrapping them in nn.Parameter;
        # calling .to(device) on an nn.Parameter can return a plain tensor that is not registered.
        self.weight = nn.Parameter(sampler.sample([in_features, out_features]).to(device), requires_grad=True)
        if bias:
            self.bias = nn.Parameter(sampler.sample([out_features]).to(device), requires_grad=True)
        else:
            self.bias = None

    def forward(self, h, adj_norm):
        h_w = torch.matmul(h, self.weight)
        out = torch.matmul(adj_norm, h_w)
        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        return f'GraphConvolutionLayer(in_features={self.in_features},out_features={self.out_features},bias={self.allow_bias},device="{self.device}")'
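
# Note: a single GraphConvolutionLayer computes out = adj_norm @ (h @ weight) (+ bias),
# i.e. the standard GCN propagation H^(l+1) = A_hat * H^(l) * W^(l).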


class GraphConvolutionNet(nn.Module):
    """
    GCN网络,本次实验没有用到

    """
    def __init__(self, n_feature, n_hidden, n_class, n_layer=0, bias=False, device='cpu'):
        super(GraphConvolutionNet, self).__init__()
        self.device = device

        # Build the network; n_hidden is the hidden layer width, n_layer the number of hidden layers
        self.layers = nn.ModuleList()
        self.layers.append(GraphConvolutionLayer(n_feature, n_hidden, bias))
        for _ in range(n_layer):
            self.layers.append(GraphConvolutionLayer(n_hidden, n_hidden, bias))
        self.layers.append(GraphConvolutionLayer(n_hidden, n_class, bias))
        self.layers.to(device)

    def forward(self, h, adj):
        # Compute the normalized adjacency matrix
        n_samples = adj.shape[0]
        adj_norm = adj + torch.eye(n_samples).to(self.device)
        m = torch.diag(torch.sqrt(torch.tensor(1).to(self.device) / torch.sum(adj_norm, dim=1)))
        adj_norm = torch.matmul(torch.matmul(m, adj_norm), m)

        # Compute the output layer by layer
        for layer in self.layers[0:-1]:
            h = layer(h, adj_norm)
            h = torch.relu(h)

        # The output layer uses a sigmoid activation followed by softmax
        h = self.layers[-1](h, adj_norm)
        h = torch.sigmoid(h)
        return torch.softmax(h, dim=1)


class Attention(nn.Module):
    """
    注意力模型

    """
    def __init__(self, in_features, hidden_features, type_idx, device='cpu'):
        super(Attention, self).__init__()
        self.type_idx = type_idx
        self.hidden_layer = nn.Linear(in_features, hidden_features).to(device)  # an ordinary linear layer
        self.device = device
        # Initialize the attention weight vector (moved to the device before wrapping in nn.Parameter)
        sampler = uniform.Uniform(-1.0, 1.0)
        self.attention = nn.Parameter(sampler.sample([2 * hidden_features, 1]).to(device), requires_grad=True)

    def forward(self, h_all):
        # Compute the attention scores
        h = self.hidden_layer(h_all).transpose(0, 1)  # h_all is a 3-D tensor, so the linear layer broadcasts over the first two dimensions
        depth = h.shape[0]  # number of node types (3 in this experiment)
        h = torch.cat([h, torch.stack([h[self.type_idx]] * depth, dim=0)], dim=2)

        weights = torch.matmul(h, self.attention.to(self.device)).transpose(0, 1)
        weights = nn.LeakyReLU()(weights)
        weights = torch.log_softmax(weights, dim=1)

        # Compute the final features according to Eq. 7 of the paper
        h_out = torch.matmul(weights.transpose(1, 2), h_all).squeeze(1)
        return weights, h_out
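
# Shape walkthrough of Attention.forward (as used by HGraphConvolutionNet below): h_all has shape
# (N, T, F), with T the number of node types. The hidden layer plus transpose gives
# (T, N, hidden_features); concatenating each type's features with the own-type features gives
# (T, N, 2 * hidden_features); multiplying by the attention vector yields per-type scores which,
# after LeakyReLU, a transpose and log_softmax over the T types, weight the type-wise features
# to produce h_out of shape (N, F).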


class HGraphConvolutionNet(nn.Module):
    def __init__(self, n_features, n_hidden, n_class, n_layer, bias, device='cpu', dropout=0.95, allow_attention=True):
        super(HGraphConvolutionNet, self).__init__()
        # Network parameters.
        self.device = device
        self.layers = nn.ModuleList()  # one sub-ModuleList of GCN layers per node type
        self.allow_attention = allow_attention
        self.attentions = nn.ModuleList()  # one sub-ModuleList of Attention modules per node type
        self.n_class = n_class
        self.n_layer = n_layer
        self.n_hidden = n_hidden
        self.n_features = n_features
        self.n_modules = len(n_features)
        self.dropout = dropout

        for idx_i in range(self.n_modules):
            # Initialize the sub-module for this node type
            module = nn.ModuleList()
            attention = nn.ModuleList()

            # Input layer
            input_layer = GraphConvolutionLayer(n_features[idx_i], n_hidden[0], bias)
            module.append(input_layer)

            attention_input = Attention(in_features=n_hidden[0], hidden_features=50, type_idx=idx_i, device=device)
            attention.append(attention_input)

            # Hidden layers
            for idx_j in range(len(n_hidden) - 1):
                hidden_layer = GraphConvolutionLayer(in_features=n_hidden[idx_j], out_features=n_hidden[idx_j + 1],
                                                     bias=bias)
                module.append(hidden_layer)
                attention_hidden = Attention(in_features=n_hidden[idx_j + 1], hidden_features=50, type_idx=idx_i,
                                             device=device)
                attention.append(attention_hidden)
            # Output layer
            output_layer = GraphConvolutionLayer(in_features=n_hidden[-1], out_features=n_class, bias=bias)
            module.append(output_layer)
            attention_output = Attention(in_features=n_class, hidden_features=50, type_idx=idx_i, device=device)
            attention.append(attention_output)

            # Register the sub-modules
            self.layers.append(module)
            self.attentions.append(attention)

        # CPU or CUDA.
        self.layers.to(device)
        self.attentions.to(device)

    def forward(self, hs, adj, splits, normalize=True):
        h_out = None
        n_samples = adj.shape[0]
        # Normalize the adjacency with Eq. 1 of the paper unless the caller already did so
        if normalize:
            adj_norm = adj + torch.eye(n_samples).to(self.device)
            m = torch.diag(torch.sqrt(torch.tensor(1).to(self.device) / torch.sum(adj_norm, dim=1)))
            adj_norm = torch.matmul(torch.matmul(m, adj_norm), m)
        else:
            adj_norm = adj
        if self.allow_attention:
            # Forward pass with type-level attention; see Eq. 7 of the paper
            for idx_layer in range(self.n_layer):
                h_out_list = []
                for idx_module_1 in range(self.n_modules):
                    h_layers = []  # features propagated from each node type
                    for idx_module_2 in range(self.n_modules):
                        A_r = adj_norm[splits[idx_module_1], :][:, splits[idx_module_2]]  # slice out the relevant adjacency block
                        layer = self.layers[idx_module_2][idx_layer]  # GCN layer for node type idx_module_2 at this depth
                        if h_out is None:
                            h_r = hs[idx_module_2]
                        else:
                            h_r = h_out[splits[idx_module_2], :]
                        h_r_next = layer(h_r, A_r)
                        h_layers.append(h_r_next)
                    _, h_out_part = self.attentions[idx_module_1][idx_layer](torch.stack(h_layers, dim=1))
                    h_out_list.append(h_out_part)
                h_out = torch.cat(h_out_list)

                if idx_layer != self.n_layer - 1:
                    h_out = torch.relu(h_out)
                    h_out = F.dropout(h_out, self.dropout, self.training)
        else:
            # Forward pass without attention; see Eq. 2 of the paper
            h_next = None
            for idx_layer in range(self.n_layer):
                for idx_module in range(self.n_modules):
                    A_r = adj_norm[:, splits[idx_module]]
                    layer = self.layers[idx_module][idx_layer]
                    if h_out is None:
                        h_r = hs[idx_module]
                    else:
                        h_r = h_out[splits[idx_module], :]
                    h_r_next = layer(h_r, A_r)
                    if h_next is None:
                        h_next = h_r_next
                    else:
                        h_next = h_next + h_r_next
                h_out = h_next
                if idx_layer != self.n_layer - 1:
                    h_out = torch.relu(h_out)
                    h_out = F.dropout(h_out, self.dropout, self.training)
                h_next = None
        # The output layer uses a sigmoid activation followed by softmax
        h_out = torch.sigmoid(h_out)
        return torch.softmax(h_out, dim=1)
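
As a quick sanity check, the network can be exercised on random data. The snippet below is a minimal sketch with made-up sizes (20 documents, 5 topics, 8 entities, 4 classes) and a hypothetical file name; it rebuilds inline what block_diag and block_split do in main.py so that it runs on its own.

smoke_test.py

# -*- coding: utf-8 -*-
# Hypothetical smoke test; the sizes below are placeholders, not the real Snippets statistics.
import torch
from network import HGraphConvolutionNet

n_d, n_t, n_e, n_class = 20, 5, 8, 4   # documents, topics, entities, classes
f_d, f_t, f_e = 30, 10, 12             # per-type feature dimensions

features = [torch.rand(n_d, f_d), torch.rand(n_t, f_t), torch.rand(n_e, f_e)]

def rand_adj(n, m):
    """Random 0/1 adjacency block."""
    return (torch.rand(n, m) > 0.7).float()

def sym(a):
    """Symmetrize a square 0/1 block."""
    return ((a + a.T) > 0).float()

adj_dd, adj_tt, adj_ee = sym(rand_adj(n_d, n_d)), sym(rand_adj(n_t, n_t)), sym(rand_adj(n_e, n_e))
adj_dt, adj_de, adj_te = rand_adj(n_d, n_t), rand_adj(n_d, n_e), rand_adj(n_t, n_e)

# Same block layout as block_diag() in main.py.
adj = torch.cat([
    torch.cat([adj_dd, adj_dt, adj_de], dim=1),
    torch.cat([adj_dt.T, adj_tt, adj_te], dim=1),
    torch.cat([adj_de.T, adj_te.T, adj_ee], dim=1),
])

# Same index ranges as block_split() in main.py.
splits = [torch.arange(0, n_d),
          torch.arange(n_d, n_d + n_t),
          torch.arange(n_d + n_t, n_d + n_t + n_e)]

net = HGraphConvolutionNet([f_d, f_t, f_e], [64], n_class, n_layer=2, bias=False, dropout=0.5)
out = net(features, adj, splits, normalize=True)  # shape: (n_d + n_t + n_e, n_class)
print(out[:n_d].argmax(dim=1))                    # predicted class per document node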
