Hierarchical Attention Network for Document Classification--tensorflow实现篇

本文档介绍了如何使用Tensorflow实现Hierarchical Attention Network(HAN)模型,适用于文本分类任务。作者首先介绍了数据集的获取和预处理,然后详细阐述了模型的构建,包括双向GRU和注意力机制。接着,描述了模型训练过程,提到了RNN模型中常见的梯度截断技术。最后,展示了初步的训练结果,并提供了其他相关资源链接。
摘要由CSDN通过智能技术生成

上周我们介绍了Hierarchical Attention Network for Document Classification这篇论文的模型架构,这周抽空用tensorflow实现了一下,接下来主要从代码的角度介绍如何实现用于文本分类的HAN模型。

数据集

首先介绍一下数据集,这篇论文中使用了几个比较大的数据集,包括IMDB电影评分,yelp餐馆评价等等。选定使用yelp2013之后,一开始找数据集的时候完全处于懵逼状态,所有相关的论文和资料里面出现的数据集下载链接都指向YELP官网,但是官网上怎么都找不到相关数据的下载,然后就各种搜感觉都搜不到==然后就好不容易在github上面找到了,MDZZ,我这都是在写什么,绝对不是在凑字数,单纯的吐槽数据不好找而已。链接如下:
https://github.com/rekiksab/Yelp/tree/master/yelp_challenge/yelp_phoenix_academic_dataset
这里面好像不止一个数据集,还有user,business等其他几个数据集,不过在这里用不到罢了。先来看一下数据集的格式,如下,每一行是一个评论的文本,是json格式保存的,主要有vote, user_id, review_id, stars, data, text, type, business_id几项,针对本任务,只需要使用stars评分和text评论内容即可。这里我选择先将相关的数据保存下来作为数据集。代码如下所示:

{"votes": {"funny": 0, "useful": 5, "cool": 2}, "user_id": "rLtl8ZkDX5vH5nAx9C3q5Q", "review_id": "fWKvX83p0-ka4JS3dc6E5A", "stars": 5, "date": "2011-01-26", "text": "My wife took me here on my birthday for breakfast and it was excellent.  The weather was perfect which made sitting outside overlooking their grounds an absolute pleasure.  Our waitress was excellent and our food arrived quickly on the semi-busy Saturday morning.  It looked like the place fills up pretty quickly so the earlier you get here the better.\n\nDo yourself a favor and get their Bloody Mary.  It was phenomenal and simply the best I've ever had.  I'm pretty sure they only use ingredients from their garden and blend them fresh when you order it.  It was amazing.\n\nWhile EVERYTHING on the menu looks excellent, I had the white truffle scrambled eggs vegetable skillet and it was tasty and delicious.  It came with 2 pieces of their griddled bread with was amazing and it absolutely made the meal complete.  It was the best \"toast\" I've ever had.\n\nAnyway, I can't wait to go back!", "type": "review", "business_id": "9yKzy9PApeiPPOUJEtnvkg"}

数据集的预处理操作,这里我做了一定的简化,将每条评论数据都转化为30*30的矩阵,其实可以不用这么规划,只需要将大于30的截断即可,小鱼30的不需要补全操作,只是后续需要给每个batch选定最大长度,然后获取每个样本大小,这部分我还没有太搞清楚,等之后有时间再看一看,把这个功能加上就行了。先这样凑合用==

#coding=utf-8
import json
import pickle
import nltk
from nltk.tokenize import WordPunctTokenizer
from collections import defaultdict

#使用nltk分词分句器
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
word_tokenizer = WordPunctTokenizer()

#记录每个单词及其出现的频率
word_freq = defaultdict(int)

# 读取数据集,并进行分词,统计每个单词出现次数,保存在word freq中
with open('yelp_academic_dataset_review.json', 'rb') as f:
    for line in f:
        review = json.loads(line)
        words = word_tokenizer.tokenize(review['text'])
        for word in words:
            word_freq[word] += 1

    print "load finished"

# 将词频表保存下来
with open('word_freq.pickle', 'wb') as g:
    pickle.dump(word_freq, g)
    print len(word_freq)#159654
    print "word_freq save finished"

num_classes = 5
# 将词频排序,并去掉出现次数最小的3个
sort_words = list(sorted(word_freq.items(), key=lambda x:-x[1]))
print sort_words[:10], sort_words[-10:]

#构建vocablary,并将出现次数小于5的单词全部去除,视为UNKNOW
vocab = {}
i = 1
vocab['UNKNOW_TOKEN'] = 0
for word, freq in word_freq.items():
    if freq > 5:
        vocab[word] = i
        i 
  • 13
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 35
    评论
评论 35
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值