Text Level Graph Neural Network for Text Classification(https://arxiv.org/pdf/1910.02356.pdf)
一、文章概述
1.1 模型图
图1:单一文本 "He is very proud of you." 的图结构。为了便于展示,图中我们为节点 "very" 设置了 p = 2(对应的节点和边用红色表示),其他节点设置了 p = 1(蓝色表示)。在实际情况下,一次实验中所有节点的 p 值是统一的。图中的所有参数均取自全局共享表示矩阵,该矩阵显示在图的底部。
1.2构图方式
其中 N 和 E 分别是图的节点集和边集,N 中的单词表示和 E 中的边权重均取自全局共享矩阵。p 表示图中与每个单词相连的相邻单词的数量。此外,我们将训练集中出现次数少于 k 次的边统一映射到一条"公共"边上,以使其参数得到充分训练。
1.3消息传递
卷积可以从局部特征中提取信息(LeCun 等, 1989)。在图域中,卷积通过频谱方法(Bruna 等人, 2014;Henaff 等人, 2015)或非频谱方法(Duvenaud 等人, 2015)实现。本文采用一种称为消息传递机制(MPM)的非频谱方法(Gilmer et al., 2017)进行卷积。MPM 首先从相邻节点收集信息,并根据节点的原始表示和收集到的信息更新其表示,定义为:
其中 $M_n \in \mathbb{R}^d$ 是节点 n 从其邻居收到的消息;max 是归约函数,它取每一维上的最大值组合成新的向量作为输出。$\mathcal{N}_n^p$ 表示原始文本中与 n 最近的 p 个单词所对应的节点;$e_{an} \in \mathbb{R}^1$ 是从节点 a 到节点 n 的边权重,可在训练过程中更新;$r_n \in \mathbb{R}^d$ 表示节点 n 的前一步表示。$\eta_n \in \mathbb{R}^1$ 是节点 n 的可训练变量,指示应保留 $r_n$ 中多少信息。$r_n' \in \mathbb{R}^d$ 表示节点 n 更新后的表示。
MPM使节点的表示受邻域影响,这意味着这些表示可以从上下文中获取信息。 因此,即使对于多义词,上下文中的确切含义也可以通过来自邻居的加权信息的影响来确定。 此外,文本级图的参数取自全局共享矩阵,这意味着表示也可以像其他基于图的模型一样带入全局信息。
1.4文本表示
最后,使用文本中所有节点的表示来预测文本的标签:
1.5训练损失
训练的目的是最小化真实标签(ground truth)与预测标签之间的交叉熵损失:
二、参数设置
我们将节点表示的维数设置为 300,并使用随机向量或 GloVe(Pennington 等, 2014)进行初始化。将 2.1 节中讨论的 k 设置为 2。我们使用 Adam 优化器(Kingma 和 Ba, 2014),初始学习率为 $10^{-3}$,L2 权重衰减设置为 $10^{-4}$。在全连接层之后应用 dropout(保留概率 0.5)。我们模型的批量大小为 32。如果验证损失连续 10 个周期没有下降,则停止训练。
对于基准模型,我们使用其原始文件或实施中的默认参数设置。 对于使用预训练词嵌入的模型,我们使用了300维GloVe词嵌入。
三、代码
3.1代码结构
先运行preprocess.py,再运行train.py
temp中有Google-vectors-negative.bin文件需要提前下载
链接:https://pan.baidu.com/s/1lCp-A3yrmxS4VOy1bnrEog
提取码:plcy
3.2config.py
import os

# Project root: the directory this config file lives in.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Model checkpoints go to record/; preprocessing artifacts go to temp/.
RECORD_PATH = ROOT + "/record"
TEMP_PATH = ROOT + "/temp"

# Hyper-parameters shared by preprocessing, model construction and training.
args = dict(
    num_words=53867,      # vocabulary size (padding word index = num_words)
    num_edges=3121369,    # distinct edge count (padding edge index = num_edges)
    num_classes=20,       # 20 Newsgroups labels
    embedding_dim=300,    # node representation size (GloVe-compatible)
    batch_size=32,
    dropout=0.5,
    L2=1e-4,              # Adam weight decay
    lr=1e-3,
    window_size=2,        # p: neighbours on each side of a word
)
3.3dataset.py
from torch.utils.data import Dataset
from config import TEMP_PATH
import numpy as np
import torch
class MyDataset(Dataset):
    """20 Newsgroups samples preprocessed into text-level graphs.

    Loads the arrays written by preprocess.py from TEMP_PATH and splits them
    into a train/test partition at index ``train_num``.

    Each item is a 4-tuple of long tensors:
        masters      (seq_len,)                 word (node) indices
        salves_list  (seq_len, 2 * window_size) neighbour node indices per word
        edges_list   (seq_len, 2 * window_size) edge indices per (word, neighbour)
        target       ()                         class label
    """

    def __init__(self, cate='train', seq_len=496, train_num=11314):
        # train_num defaults to the official 20NG train-split size (11314);
        # it is a parameter so other split points can reuse this class.
        masters_list = np.load(TEMP_PATH + r"/masters_list.npy")[:, :seq_len]
        salves_list_list = np.load(TEMP_PATH + r"/salves_list_list.npy")[:, :seq_len]
        edges_list_list = np.load(TEMP_PATH + r"/edges_list_list.npy")[:, :seq_len]
        targets = np.load(TEMP_PATH + r"/targets.npy")
        # Rows before train_num are the training partition, the rest test.
        part = slice(None, train_num) if cate == 'train' else slice(train_num, None)
        self.masters_list = torch.tensor(masters_list[part], dtype=torch.long)
        self.salves_list_list = torch.tensor(salves_list_list[part], dtype=torch.long)
        self.edges_list_list = torch.tensor(edges_list_list[part], dtype=torch.long)
        self.targets = torch.tensor(targets[part], dtype=torch.long)

    def __getitem__(self, item):
        return (self.masters_list[item],
                self.salves_list_list[item],
                self.edges_list_list[item],
                self.targets[item])

    def __len__(self):
        return len(self.masters_list)
3.4model.py
from torch import nn
import torch
class TextLevelGNN(nn.Module):
    """Text-Level GNN for text classification (Huang et al., 2019).

    Each text is a small graph: every word ("master" node) receives messages
    from its p neighbouring words ("slave" nodes) through scalar edge weights
    drawn from a globally shared matrix, then a gate decides how much of the
    old representation to keep.

    Args:
        word2vector: optional ``(num_words + 1, embedding_dim)`` numpy array of
            pretrained word vectors (last row is the padding row). If None,
            embeddings are learned from random initialisation.
        num_words: vocabulary size; index ``num_words`` is the padding word.
        num_edges: distinct edge count; index ``num_edges`` is the padding edge.
        num_classes: number of output labels.
        embedding_dim: node representation size.
        dropout: dropout probability before the classifier.

    The defaults reproduce the original hard-coded configuration, so existing
    callers (``TextLevelGNN(word2vector)``) are unaffected.
    """

    def __init__(self, word2vector, num_words=53867, num_edges=3121369,
                 num_classes=20, embedding_dim=300, dropout=0.5):
        super(TextLevelGNN, self).__init__()
        if word2vector is None:
            # Third positional arg is padding_idx: the padding row stays zero.
            self.w_embed = nn.Embedding(num_words + 1, embedding_dim, num_words)
        else:
            # freeze=False: pretrained vectors are fine-tuned during training.
            self.w_embed = nn.Embedding.from_pretrained(
                torch.from_numpy(word2vector).float(), False, num_words)
        # Scalar edge weight e_an per edge and scalar gate eta_n per word.
        self.e_embed = nn.Embedding(num_edges + 1, 1, padding_idx=num_edges)
        self.k_embed = nn.Embedding(num_words + 1, 1, padding_idx=num_words)
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, num_classes, bias=True),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, masters, slaves_list, edges_list):
        """Score a batch of text graphs.

        masters:     (batch, seq) word indices
        slaves_list: (batch, seq, 2p) neighbour word indices per position
        edges_list:  (batch, seq, 2p) edge indices per (word, neighbour) pair
        Returns (batch, num_classes) class scores.
        """
        Rn = self.w_embed(masters)       # (B, S, D) node representations
        Ra = self.w_embed(slaves_list)   # (B, S, 2p, D) neighbour representations
        Ean = self.e_embed(edges_list)   # (B, S, 2p, 1) edge weights
        # Message M_n: max-reduce the weighted neighbour vectors over the window.
        Mn = (Ra * Ean).max(dim=2)[0]    # (B, S, D)
        # Gate eta_n (broadcast over D) mixes the message with the old state.
        Nn = self.k_embed(masters)       # (B, S, 1)
        x = (1 - Nn) * Mn + Nn * Rn
        # Sum-pool node states over the sequence, then classify.
        x = self.fc(x.sum(dim=1))
        return x
if __name__ == '__main__':
    # Smoke test: forward a batch of random indices and print the score shape.
    vocab_size = 50000
    edge_count = 3000000
    n_classes = 20
    dim = 300
    half_window = 2
    bs, length = 64, 320

    node_ids = torch.randint(0, vocab_size + 1, (bs, length))
    neighbour_ids = torch.randint(0, vocab_size + 1, (bs, length, half_window * 2))
    edge_ids = torch.randint(0, edge_count + 1, (bs, length, half_window * 2))

    net = TextLevelGNN(None)
    scores = net(node_ids, neighbour_ids, edge_ids)
    print(scores.shape)
3.5preprocess.py
import gensim
import joblib
import numpy as np
from collections import Counter
from nltk.tokenize import word_tokenize
from sklearn.datasets import fetch_20newsgroups
from config import TEMP_PATH
def get_data() -> tuple:
    """Fetch the full 20 Newsgroups corpus as (texts, targets)."""
    bunch = fetch_20newsgroups(subset='all', random_state=1)
    return bunch.data, list(bunch.target)
def is_valid_word(word: str) -> bool:
    """Keep tokens made of letters/digits plus at most one apostrophe and at
    most three dots, requiring at least one non-punctuation character."""
    apostrophes = word.count("'")
    dots = word.count(".")
    if apostrophes > 1 or dots > 3:
        return False
    for ch in word:
        if ch in "'.":
            continue
        if not (ch.isnumeric() or ch.isalpha()):
            return False
    # Reject tokens consisting solely of apostrophes/dots.
    return len(word) > apostrophes + dots
def padding_seq(seq: list, padding_idx: int, padding_len: int) -> list:
    """Return `seq` truncated or right-padded with `padding_idx` to exactly
    `padding_len` elements."""
    padded = seq + [padding_idx] * max(0, padding_len - len(seq))
    return padded[:padding_len]
def get_word2vec(word2index: dict, embedding_dim=300):
    """Build the (len(word2index) + 1, embedding_dim) embedding table.

    Rows follow word2index order: words found in the pretrained binary get
    their pretrained vector, other words a random one, and the final row is
    the all-zero padding vector.

    NOTE(review): ``KeyedVectors.vocab`` was removed in gensim 4.x (use
    ``key_to_index`` there) — confirm the installed gensim version. Also
    verify the binary file name; the README mentions a different file.
    """
    # BUGFIX: the path was joined with a Windows backslash
    # (r"\golve_pretraining.bin"), which breaks on POSIX systems; every
    # other path in this project uses "/".
    pretrained = gensim.models.KeyedVectors.load_word2vec_format(
        TEMP_PATH + "/golve_pretraining.bin", binary=True)
    known = set(pretrained.vocab)
    index2vec = {i: pretrained[w] if w in known else list(np.random.random(embedding_dim))
                 for w, i in word2index.items()}
    # Append the zero vector for the padding index (= len(word2index)).
    return [index2vec[i] for i in range(len(index2vec))] + [[0.] * embedding_dim]
# ---- Preprocessing pipeline: turn each text into a padded graph encoding ----
print("get dataset")
texts, targets = get_data()
words_list = [word_tokenize(text) for text in texts]
print("get word_set")
# Words appearing fewer than MIN_COUNT times are dropped from the vocabulary.
MIN_COUNT = 5
word2count = Counter([w for words in words_list for w in words])
word_count = sorted(list(word2count.items()), key=lambda x: x[1], reverse=True)
word_count = [[w, c] for w, c in word_count if c >= MIN_COUNT]
word_set = {w for w, c in word_count}
word2index = {w: i for i, w in enumerate(word_set)}
# First pass keeps only surviving words per text — used just for length stats.
masters_list = [[w for w in words if w in word_set] for words in words_list]
MAX_LEN = 2000
lens = sorted([len(e) for e in masters_list])
# Clip lengths at MAX_LEN so extreme outliers don't inflate padding_len.
lens = [e if e < MAX_LEN else MAX_LEN for e in lens]
# padding_len is chosen so CONTAINS_PER of the texts fit without truncation.
CONTAINS_PER = 0.85
avg_len = int(sum(lens) / len(lens))
# Padding word index: one past the last vocabulary index.
padding_word = len(word_set)
padding_len = lens[int(CONTAINS_PER * len(words_list))]
print("avg_len", avg_len)
print("padding_idx", padding_word)
print("padding_len", padding_len)
print("get masters_list")
# Second pass: texts as padded index sequences ("master" nodes).
masters_list = [[word2index[w] for w in words if w in word_set] for words in words_list]
masters_list = [padding_seq(seq, padding_word, padding_len) for seq in masters_list]
print("get salves_list_list")
WINDOW_SIZE = 2
# For every position i, collect the WINDOW_SIZE neighbours on each side
# ("slave" nodes); out-of-range positions map to the padding word.
salves_list_list = [[[masters[j] if 0 <= j < len(masters) else padding_word
                      for j in range(i - WINDOW_SIZE, i + WINDOW_SIZE + 1) if i != j]
                     for i in range(len(masters))] for masters in masters_list]
print("get edge_set")
# Edges are ordered (master, slave) pairs serialised via str([...]) so they
# can be used as dict keys; pairs touching the padding word are excluded.
edge_set = {str([master, salve]) for masters, salves_list in list(zip(masters_list, salves_list_list))
            for master, salves in list(zip(masters, salves_list))
            for salve in salves if master != padding_word and salve != padding_word}
edge2index = {e: i for i, e in enumerate(edge_set)}
# Padding edge index: one past the last real edge index.
padding_edge = len(edge_set)
print("get edges_list_list")
edges_list_list = [[[str([master, salve]) for salve in salves]
                    for master, salves in list(zip(masters, salves_list))]
                   for masters, salves_list in list(zip(masters_list, salves_list_list))]
# Unknown pairs (those involving padding) map to the padding edge.
edges_list_list = [[[edge2index[e] if e in edge_set else padding_edge for e in edges]
                    for edges in edges_list] for edges_list in edges_list_list]
print("get word2vector")
word2vector = get_word2vec(word2index)
print("save values")
# Persist everything train.py / dataset.py needs.
np.save(TEMP_PATH + r"/masters_list.npy", masters_list)
np.save(TEMP_PATH + r"/salves_list_list.npy", salves_list_list)
np.save(TEMP_PATH + r"/edges_list_list.npy", edges_list_list)
np.save(TEMP_PATH + r"/targets.npy", targets)
np.save(TEMP_PATH + r"/word2vector.npy", word2vector)
joblib.dump(word2index, TEMP_PATH + r"/word2index.pkl")
joblib.dump(edge2index, TEMP_PATH + r"/edge2index.pkl")
# Record the preprocessing settings; the padding sizes go in the file name.
joblib.dump(
    {"MIN_COUNT": MIN_COUNT,
     "MAX_LEN": MAX_LEN,
     "CONTAINS_PER": CONTAINS_PER,
     "WINDOW_SIZE": WINDOW_SIZE},
    TEMP_PATH + f"/pad_len={padding_len},pad_word={padding_word},pad_edge={padding_edge}.pkl")
3.6train.py
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from model import TextLevelGNN
from dataset import MyDataset
from config import RECORD_PATH, TEMP_PATH
import numpy as np
def train_eval(cate, data_loader, model, optimizer, loss_func):
    """Run one epoch over ``data_loader`` on the GPU.

    Args:
        cate: 'train' updates the model; any other value only evaluates.
        data_loader: yields (masters, salves_list, edges_list, target) batches.
        model: the TextLevelGNN (already on CUDA).
        optimizer: used only in train mode; may be None otherwise.
        loss_func: e.g. nn.CrossEntropyLoss.

    Returns:
        (accuracy_percent, mean_loss) as Python floats.
    """
    is_train = cate == 'train'
    if is_train:
        model.train()
    else:
        model.eval()
    correct, loss_sum = 0.0, 0.0
    # BUGFIX: evaluation previously built the autograd graph for every batch;
    # torch.no_grad() skips that cost without changing the computed values.
    with torch.enable_grad() if is_train else torch.no_grad():
        for masters, salves_list, edges_list, target in data_loader:
            masters = masters.cuda()
            salves_list = salves_list.cuda()
            edges_list = edges_list.cuda()
            target = target.cuda()
            y = model(masters, salves_list, edges_list)
            loss = loss_func(y, target)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # .item() replaces the deprecated .data and keeps plain floats.
            correct += y.max(dim=1)[1].eq(target).sum().item()
            loss_sum += loss.item()
    acc = correct * 100 / len(data_loader.dataset)
    return acc, loss_sum / len(data_loader)
if __name__ == '__main__':
    # Hyper-parameters; these mirror config.args and the values reported in
    # the paper (Adam, lr=1e-3, L2 weight decay 1e-4, batch size 32).
    num_words = 53867
    num_edges = 3121369
    num_classes = 20
    embedding_dim = 300
    # start > 0 resumes training from record/model.<start>.pth.
    start = 0
    padding_len = 400
    batch_size = 32
    lr = 1e-3
    weight_decay = 1e-4
    print("init & load...")
    train_loader = DataLoader(MyDataset("train", padding_len), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(MyDataset("test", padding_len), batch_size=batch_size)
    # Pretrained embedding table produced by preprocess.py.
    word2vector = np.load(TEMP_PATH + r"/word2vector.npy")
    model = TextLevelGNN(word2vector)
    if start != 0: model.load_state_dict(torch.load(RECORD_PATH + '/model.{}.pth'.format(start)))
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    print("start...")
    model = model.cuda()
    for epoch in range(start + 1, 1000):
        t1 = time.time()
        train_acc, train_loss = train_eval('train', train_loader, model, optimizer, loss_func)
        # The optimizer is unused in eval mode, hence None.
        test_acc, test_loss = train_eval('test', test_loader, model, None, loss_func)
        cost = time.time() - t1
        # A checkpoint is saved every epoch; no early stopping is applied here.
        torch.save(model.state_dict(), RECORD_PATH + '/model.{}.pth'.format(epoch))
        print("epoch=%s, cost=%.2f, train:[loss=%.4f, acc=%.2f%%], test:[loss=%.4f, acc=%.2f%%]"
              % (epoch, cost, train_loss, train_acc, test_loss, test_acc))