数据集位于lda安装目录的tests文件夹中，包含三个文件：reuters.ldac、reuters.titles和reuters.tokens。
reuters.titles包含了395个文档的标题。
reuters.tokens是词汇表，包含了这395个文档中出现的所有不同单词，共4258个。
reuters.ldac有395行，第i行记录第i个文档中各个词汇出现的次数（LDA-C格式：每行为“M 词编号:次数 词编号:次数 ...”，其中M是该文档中不同词的个数）。以第0行为例，它对应第0个文档，从reuters.titles中可查到该文档的标题为“UK: Prince Charles spearheads British royal revolution. LONDON 1996-08-20”。
# !/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import lda
import lda.datasets
from pprint import pprint
def _print_top_words(topic_word, vocab, n):
    """Print the ``n`` most probable words of every topic.

    Args:
        topic_word: array whose row ``i`` is the word distribution of topic
            ``i``; its columns must line up with ``vocab``.
        vocab: sequence of vocabulary words, one per column of ``topic_word``.
        n: how many words to list per topic.
    """
    # Convert the vocabulary once instead of on every loop iteration.
    vocab_array = np.array(vocab)
    for i, topic_dist in enumerate(topic_word):
        # argsort is ascending, so reverse it and keep the first n indices.
        top_indices = np.argsort(topic_dist)[::-1][:n]
        print('*Topic {}\n- {}'.format(i, ' '.join(vocab_array[top_indices])))


def _print_dominant_topics(doc_topic, n_docs=10):
    """Print the most probable topic of each of the first ``n_docs`` documents.

    Args:
        doc_topic: array whose row ``i`` is the topic distribution of
            document ``i``.
        n_docs: number of leading documents to report; clipped to the number
            of rows so a small corpus does not raise ``IndexError``.
    """
    for i in range(min(n_docs, len(doc_topic))):
        topic_most_pr = doc_topic[i].argmax()
        print("文档: {} 主题: {} value: {}".format(i, topic_most_pr, doc_topic[i][topic_most_pr]))


def _plot_topic_word(topic_word, topic_ids):
    """Plot the full word distribution of each topic in ``topic_ids``, one subplot row per topic."""
    n_vocab = topic_word.shape[1]
    # Keep the original 0.08 ceiling, but raise it if any plotted topic peaks
    # higher, so the tallest spikes are never clipped.
    y_max = max(0.08, float(topic_word[list(topic_ids), :].max()))
    # Small margin on both sides of the word-index range (previously
    # hard-coded as (-50, 4350) for the 4258-word Reuters vocabulary).
    x_pad = max(1, n_vocab // 50)
    plt.figure(figsize=(8, 7))
    for row, k in enumerate(topic_ids):
        ax = plt.subplot(len(topic_ids), 1, row + 1)
        ax.plot(topic_word[k, :], 'r-')
        ax.set_xlim(-x_pad, n_vocab + x_pad)
        ax.set_ylim(0, y_max)
        ax.set_ylabel("概率")
        ax.set_title("主题 {}".format(k))
    # plt.xlabel acts on the current (last-created) axes, so the label is
    # drawn once, under the bottom subplot.
    plt.xlabel("词", fontsize=14)
    plt.suptitle('主题的词分布', fontsize=18)
    plt.tight_layout()
    # Reserve room at the top of the figure for the suptitle.
    plt.subplots_adjust(top=0.9)
    plt.show()


def _plot_doc_topic(doc_topic, doc_ids):
    """Stem-plot the topic distribution of each document in ``doc_ids``, one subplot row per document."""
    n_topics = doc_topic.shape[1]
    plt.figure(figsize=(8, 7))
    for row, k in enumerate(doc_ids):
        ax = plt.subplot(len(doc_ids), 1, row + 1)
        ax.stem(doc_topic[k, :], linefmt='g-', markerfmt='ro')
        ax.set_xlim(-1, n_topics + 1)
        ax.set_ylim(0, 1)
        ax.set_ylabel("概率")
        ax.set_title("文档 {}".format(k))
    # plt.xlabel acts on the current (last-created) axes.
    plt.xlabel("主题", fontsize=14)
    plt.suptitle('文档的主题分布', fontsize=18)
    plt.tight_layout()
    # Reserve room at the top of the figure for the suptitle.
    plt.subplots_adjust(top=0.9)
    plt.show()


def main():
    """Fit a 20-topic LDA model on the Reuters sample data shipped with ``lda``.

    Prints summaries of the inputs and the fitted distributions, then plots
    the word distributions of a few topics and the topic distributions of a
    few documents.
    """
    # document-term matrix
    X = lda.datasets.load_reuters()
    print("type(X): {}".format(type(X)))
    print("shape: {}\n".format(X.shape))
    print(X[:10, :10])

    # the vocab
    vocab = lda.datasets.load_reuters_vocab()
    print("type(vocab): {}".format(type(vocab)))
    print("len(vocab): {}\n".format(len(vocab)))
    print(vocab[:10])

    # titles for each story
    titles = lda.datasets.load_reuters_titles()
    print("type(titles): {}".format(type(titles)))
    print("len(titles): {}\n".format(len(titles)))
    pprint(titles[:10])

    print('LDA start ----')
    topic_num = 20
    # random_state=1 seeds the model so repeated runs give the same topics.
    model = lda.LDA(n_topics=topic_num, n_iter=800, random_state=1)
    model.fit(X)

    # topic-word distributions
    topic_word = model.topic_word_
    print("type(topic_word): {}".format(type(topic_word)))
    print("shape: {}".format(topic_word.shape))
    print(vocab[:5])
    print(topic_word[:, :5])
    _print_top_words(topic_word, vocab, n=7)

    # document-topic distributions
    doc_topic = model.doc_topic_
    print("type(doc_topic): {}".format(type(doc_topic)))
    print("shape: {}".format(doc_topic.shape))
    _print_dominant_topics(doc_topic, n_docs=10)

    # Use a CJK-capable font so the Chinese labels render, and draw minus
    # signs as ASCII '-' (a common workaround when the font lacks U+2212).
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False

    # NOTE: these topic ids assume topic_num >= 20.
    _plot_topic_word(topic_word, [0, 5, 9, 14, 19])
    _plot_doc_topic(doc_topic, [1, 3, 4, 8, 9])


if __name__ == "__main__":
    main()
文档: 0 主题: 8 value: 0.5308695652173914
文档: 1 主题: 13 value: 0.25434782608695655
文档: 2 主题: 14 value: 0.6489539748953975
文档: 3 主题: 8 value: 0.4789473684210527
文档: 4 主题: 14 value: 0.7568265682656825
文档: 5 主题: 14 value: 0.844097222222222
文档: 6 主题: 14 value: 0.8540404040404042
文档: 7 主题: 14 value: 0.8638225255972695
文档: 8 主题: 14 value: 0.7388461538461537
文档: 9 主题: 8 value: 0.48157894736842105