HGAT半监督短文本分类的异构图注意网络论文阅读笔记及代码复现

Heterogeneous Graph Attention Networks for Semi-supervised Short Text Classification. Linmei Hu,Tianchi Yang, Chuan Shi, Houye Ji, Xiaoli Li论文阅读笔记及代码复现

未经准许,禁止转载
在这里插入图片描述
在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述

代码

main.py

# -*- coding: utf-8 -*-
import torch
import numpy as np
from network import HGraphConvolutionNet
import torch.nn as nn
import torch.optim as optim
import os
import gc
import matplotlib.pyplot as plt
import time

#固定随机数种子,方便复现结果
np.random.seed(1)
torch.random.manual_seed(1)
torch.set_default_tensor_type(torch.FloatTensor)

#是否使用GPU
gpu = True

 默认数据路径和模型保存路径
fast_snippet_path = 'data/snippets.pt'
snippet_path = 'data/snippets.npz'
best_model_path = 'model/best_model.pt'

#如果训练时GPU不可用,则使用CPU
train_device = 'cuda' if (torch.cuda.is_available() and gpu) else 'cpu'
test_device = 'cpu' # 强制使用CPU进行测试

 训练参数
learning_rate = 0.005
weight_decay = 5e-6
max_epochs = 1000
num_per_class = 100 # 训练集和验证集中每一类别的样本总和,训练时训练集和验证集各占一半,如100的话,训练集每一类别有50个,验证集每一类别有50个

 网络参数
num_hiddens = [512]
num_layer = 2

def block_diag(dd, dt, de, tt, te, ee):
    """
    组合特征矩阵,输出 N*N矩阵

    """
    _num_document = len(dd)
    _num_topic = len(tt)
    _num_entity = len(ee)
    _part_1 = torch.cat([dd, dt, de], dim=1)
    _part_2 = torch.cat([dt.T, tt, te], dim=1)
    _part_3 = torch.cat([de.T, te.T, ee], dim=1)
    return torch.cat([_part_1, _part_2, _part_3])


def normalize_adj(adj, device):
    """
    计算归一化的邻接矩阵,具体内容参考论文中公式1

    """
    n_samples = adj.shape[0]
    adj_norm = adj + torch.eye(n_samples).to(device)
    m = torch.diag(torch.sqrt(torch.tensor(1).to(device) / torch.sum(adj_norm, dim=1)))
    adj_norm = torch.matmul(torch.matmul(m, adj_norm), m)
    return adj_norm


def block_split(num_document, num_topic, num_entity):
    """
    辅助函数,生成各类型节点在邻接矩阵中的索引

    """
    x_1 = num_document
    x_2 = num_document + num_topic
    x_3 = num_document + num_topic + num_entity
    return [torch.arange(0, x_1), torch.arange(x_1, x_2), torch.arange(x_2, x_3)]


def fast_load(path=fast_snippet_path):
    """
    载入Snippets数据(快速模式,数据占空间较大)。
    输出为主题特征、文档特征、实体特征、主题-主题关系、主题-文档关系、主题-实体关系、文档-文档关系、文档-实体关系、实体-实体关系和类别信息。

    """
    _data = torch.load(path)
    _all_document_x = _data['features'][0]
    _all_topic_x = _data['features'][1]
    _all_entity_x = _data['features'][2]
    _all_adj_document = _data['adj'][0][0]
    _all_adj_document_topic = _data['adj'][0][1]
    _all_adj_document_entity = _data['adj'][0][2]
    _all_adj_topic = _data['adj'][1][1]
    _all_adj_topic_entity = _data['adj'][1][2]
    _all_adj_entity = _data['adj'][2][2]
    _all_labels = torch.argmax(_data['labels'], dim=1)
    return _all_document_x, _all_topic_x, _all_entity_x, _all_adj_document, _all_adj_document_topic, \
           _all_adj_document_entity, _all_adj_topic, _all_adj_topic_entity, _all_adj_entity, _all_labels


def load(path=snippet_path):
    """
    载入Snippets数据(正常模式,数据占空间小,已压缩)
    输出为主题特征、文档特征、实体特征、主题-主题关系、主题-文档关系、主题-实体关系、文档-文档关系、文档-实体关系、实体-实体关系和类别信息。

    """
    _data = np.load(path)
    _all_document_x = torch.tensor(_data['d_x'])
    _all_topic_x = torch.tensor(_data['t_x'])
    _all_entity_x = torch.tensor(_data['e_x'])
    _all_adj_document = torch.tensor(_data['adj_dd'])
    _all_adj_document_topic = torch.tensor(_data['adj_dt'])
    _all_adj_document_entity = torch.tensor(_data['adj_de'])
    _all_adj_topic = torch.tensor(_data['adj_tt'])
    _all_adj_topic_entity = torch.tensor(_data['adj_te'])
    _all_adj_entity = torch.tensor(_data['adj_ee'])
    _all_labels = torch.tensor(_data['d_y'])
    return _all_document_x, _all_topic_x, _all_entity_x, _all_adj_document,
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值