Graph Convolutional Networks for Text Classification
一、模型
图示:
图的定义:
二、代码
总体目录:
datasets ---存放数据
temp ---存放变量
model.py ---模型
preprocess.py ---预处理
train.py ---训练
model.py 模型
import torch
from torch import nn
class GCNLayer(nn.Module):
    """Single graph-convolution layer: ``act(support @ (dropout(x) @ W) [+ b])``.

    Args:
        input_dim: Number of rows of the weight matrix.
        output_dim: Number of columns of the weight matrix.
        dropout: Dropout probability applied to ``x`` before the weight
            multiplication (not used on the ``first`` path).
        act: Callable applied to the layer output.
        bias: When true, a learnable bias vector of size ``output_dim`` is
            added after the multiplication with ``support``.
        first: When true, ``forward`` ignores ``x`` and uses the weight matrix
            itself as the transformed features, which equals multiplying an
            identity feature matrix by the weight.
    """

    def __init__(self, input_dim, output_dim, dropout=0.,
                 act=torch.relu, bias=False, first=False):
        super(GCNLayer, self).__init__()
        self.act = act
        self.first = first
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.rand((output_dim,)))
        self.weight = nn.Parameter(torch.rand((input_dim, output_dim)))
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight (Xavier uniform) and zero the bias, if any."""
        nn.init.xavier_uniform_(self.weight, gain=1)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x, support):
        """Apply the layer.

        Args:
            x: Input feature matrix; ignored (may be ``None``) when the layer
                was built with ``first=True``.
            support: Matrix left-multiplied onto the transformed features.

        Returns:
            ``act(support.mm(features) + bias)`` where ``features`` is the
            weight itself for a ``first`` layer, else ``dropout(x).mm(weight)``.
        """
        if self.first:
            x = self.weight
        else:
            x = self.dropout(x).mm(self.weight)
        x = support.mm(x)
        if self.bias is not None:
            # The original read `self.W.bias`, an attribute that is never
            # defined on this module; the bias parameter lives on `self.bias`.
            x = x + self.bias
        return self.act(x)
class Model(nn.Module):
    """Two stacked ``GCNLayer`` modules applied to the same ``support`` matrix.

    Args:
        num_nodes: Input dimension of the first layer's weight matrix.
        num_classes: Output dimension of the second layer.
        hidden_dim: Output dimension of the first layer / input of the second.
        dropout: Dropout probability handed to both layers.
    """

    def __init__(self, num_nodes, num_classes, hidden_dim, dropout=0.):
        super().__init__()
        # First layer: built with first=True, so it takes no input features.
        self.g1 = GCNLayer(
            input_dim=num_nodes,
            output_dim=hidden_dim,
            dropout=dropout,
            act=torch.relu,
            bias=False,
            first=True,
        )
        # Second layer: identity activation, so its output is left untransformed.
        self.g2 = GCNLayer(
            input_dim=hidden_dim,
            output_dim=num_classes,
            dropout=dropout,
            act=lambda x: x,
            bias=False,
        )

    def forward(self, support):
        """Run both layers on ``support`` and return the second layer's output."""
        hidden = self.g1(None, support)
        return self.g2(hidden, support)
preprocess.py 预处理部分:load_dataset函数需要自己构造
import re
from collections import Counter
from tqdm import tqdm
import math
import scipy.sparse as sp
from nltk.corpus import stopwords
import numpy as np
import torch
import joblib
import random
# Preprocessing configuration (module-level settings).
# mr 1701424
# NOTE(review): the figure above is carried over from the original comment;
# its meaning is not documented here -- confirm before relying on it.
dataset = "mr"
min_count = 5
stop_words = set(stopwords.words('english'))
if dataset == "mr":
    # For the "mr" dataset: no minimum-count threshold and no stop words.
    min_count = 0
    # Empty set (the original `{}` was an empty dict) so that `stop_words`
    # has the same type in both branches.
    stop_words = set()
window_size = 20
def load_dataset(dataset):
with open(f"datasets/{dataset}.txt&