这篇文章是Textsum数据处理的续篇,主要记录了再次实验中遇到的问题,以及对实验的不断改进和完善的过程。
1 再次运行模型
由于词频统计脚本实在是太慢了,在它统计完三分之一,即约三万条数据的时候,我决定开始重新运行模型。我们将处理好的9w条CNN数据,取前34600条作为训练样本,生成新的文件。将其拷贝到Textsum工作空间下的data/下,重命名为data。同时将对应的story.vocab.new词表文件拷贝到同一目录下。
吸取之前实验的经验,这次要将最大epoch设置的小一些,使用以下命令。不要忘了将原来实验的日志文档和中间结果备份起来,否则会被覆盖掉。
$ nohup bazel-bin/textsum/seq2seq_attention \
--mode=train \
--article_key=article \
--abstract_key=abstract \
--data_path=data/data \
--vocab_path=data/vocab \
--log_root=textsum/log_root \
--train_dir=textsum/log_root/train \
--max_run_steps=10000 \
> train_log 2>&1 &
但在此时我遇到了Bug:Memory Error
网上说这主要是因为Python的垃圾回收机制不够智能之类的,导致了内存不足。但实际上是因为我没有将文本文件转换为二进制文件,而直接输入给了模型,导致某个步骤出现了死循环。这里一定要细心,使用上篇文章讲过的方法生成data的二进制格式文件,重新运行。
这时又出现了Drop Example的情况:
WARNING:tensorflow:Drop an example - too long.
实际上,在模型的batch_reader.py文件中的Batcher._FillInputQueue()中有对数据文段长度的过滤:
# Filter out too-short input
if (len(enc_inputs) < self._hps.min_input_len or len(dec_inputs) < self._hps.min_input_len):
tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d', len(enc_inputs), len(dec_inputs))
continue
# If we're not truncating input, throw out too-long input
if not self._truncate_input:
if (len(enc_inputs) > self._hps.enc_timesteps or len(dec_inputs) > self._hps.dec_timesteps):
tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d', len(enc_inputs), len(dec_inputs))
continue
可以看到这里对Abstract最大长度限制为hps.dec_timesteps
,而这个值的定义在seq2seq_attention.py中:
hps = seq2seq_attention_model.HParams(
mode=FLAGS.mode, # train, eval, decode
min_lr=0.01, # min learning rate.
lr=0.15, # learning rate
batch_size=batch_size,
enc_layers=4,
enc_timesteps=120,
dec_timesteps=30,
min_input_len=2, # discard articles/summaries < than this
num_hidden=256, # for rnn cell
emb_dim=128, # If 0, don't use embedding
max_grad_norm=2,
num_softmax_samples=4096) # If 0, no sampled softmax.
大致衡量一下,超出范围的abstract的长度大都为50以下,所以我们可以将dec_timesteps修改为50。在我们使用真正的新闻标题作为摘要后就基本不会出现这个问题。在修改了源码后一定不要忘记重新使用bazel进行编译。
$ bazel build -c opt textsum/...
另外我这里重新查了资料,如果加上–config=cuda参数之所以会出现以下问题:
INFO: Options provided by the client:
Inherited 'common' options: --isatty=1 --terminal_columns=109
ERROR: Config value cuda is not defined in any .rc file
INFO: Invocation ID: ed698fdb-cd68-4831-9614-f311901731cf
是因为bazel的版本问题与tensorflow不兼容,当然我当时安装的tensorflow就不是GPU版的,也不可能对,哈哈。参考这篇博客:源码编译安装tensorflow 1.8
接下来开始运行模型,我又出现了以下的神奇BUG:
Caused by op u'save/Assign_28', defined at:
File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 213, in <module>
tf.app.run()
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 196, in main
_Train(model, batcher)
File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 86, in _Train
saver = tf.train.Saver()
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1051, in __init__
self.build()
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1081, in build
restore_sequentially=self._restore_sequentially)
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 675, in build
restore_sequentially, reshape)
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 414, in _AddRestoreOps
assign_ops.append(saveable.restore(tensors, shapes))
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 155, in restore
self.op.get_shape().is_fully_defined())
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/ops/gen_state_ops.py", line 47, in assign
use_locking=use_locking, name=name)
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
op_def=op_def)
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [266993] rhs shape= [10003]
[[Node: save/Assign_28 = Assign[T=DT_FLOAT, _class=["loc:@seq2seq/output_projection/v"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](seq2seq/output_projection/v, save/RestoreV2_28)]]
后来发现这是由于原来的log_root产生了干扰,在重新运行程序时,可能会将原有的断点加载进去,而原来的模型和现在的模型的参数维度不匹配,所以就会报错,只需要将原有的log_root下的文件都删除掉即可:
$ rm -rf textsum/log_root/*
后来发现这是由于原来的log_root产生了干扰,在重新运行程序时,可能会将原有的断点加载进去,而原来的模型和现在的模型的参数维度不匹配,所以就会报错,只需要将原有的log_root下的文件都删除掉即可:
$ rm -rf textsum/log_root/*
模型运行结束发现loss值只能最小减小到0.5左右,之前的toy data可以收敛到0.00001,可见模型并没有收敛,所以将--max_run_step
设置为10000并不合理,下次实验将会尝试100000。
2 寻找文段标题
在以上的实验中也能看出,很多我们称之为abstract的第一段,实际很长,并不能作为摘要,这个假设是很失败的,所以我们需要寻找到文段的标题,作为真正的摘要。
进入DMQA的官网下载CNN stories的html源码。之前之所以没有直接这么做是因为这个网站是纽约大学的,这个文件总共1.4GB,要翻墙下载这么大的文件谈何容易。但现在还是需要尝试一下。
我使用了Google的插件Setup VPN,当然下载这个插件也需要翻Google商店的墙,具体怎么做不方便透露,可以私下联系我。最初下载的时候,链接很不稳定,网速大概只有几十KB每秒,但更大的困难在于链接会断,最后只下载下来100MB左右的文件。好在压缩包是多文件打包的,所以我把压缩包传到服务器上,可以查看文件的内容。
查看CNN stories html的网页样例:Russian bomber buzzes U.S. aircraft carrier
可以发现这就是普通的新闻网站啊,从源码中寻找标题,有两处是匹配的:
<title>Russian bomber buzzes U.S. aircraft carrier - CNN.com</title>
<H1> Russian bomber buzzes U.S. aircraft carrier</H1>
这里你选择哪一种都可以,它们的正则表达式匹配方式对应如下:
r'<title>([^<]+)- CNN.com</title>'
r'<H1>([\s\w\.-:]+)</H1>'
很显然使用第一种方式更保险一些,之所以第二种没有用[^<]+
的方式,是因为<H1></H1>
标签之间有可能有其他控制格式的标签,但我们又很难穷举标题中可能含有的字符,所以还是使用第一种匹配方式更加稳妥。
3 修改数据格式转换脚本
我们可以修改原始脚本而不需编写从html文档解析文段的脚本,这是有一个前提的,就是我们能够将标题与已有的文段数据对应起来。很幸运,我们观察story的文件名和html的文件名,相同文段的文件名中的ID是重合的,例如:
101309778895e6ca14fca581352735a07b9fc6f7.story
101309778895e6ca14fca581352735a07b9fc6f7.html
那么我们需要将两份数据进行交叉对比,选出能够成对的文件来,并将原来的脚本(参见上一篇博客)中的abstract的引用修改成从html文件获取。但这里也不是一帆风顺的,我们遇到了类似于下面的经典BUG:
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa0 in position 41213: invalid start byte
utf-8
,utf-16
,ascii
和gbk
之类的都不能解码。这时就用到了chardet这个Python工具包。它可以在很大的编码列表中尝试解码指定文件,并通过统计数据给出对编码的猜测以及相应的置信度。我们可以向下面这样使用它:
def check_encoding(file_name):
file = open(file_name, 'rb')
info = chardet.detect(file.read())
return info["encoding"]
info变量的格式如下:
{'confidence': 0.73, 'language': '', 'encoding': 'ISO-8859-1'}
可见我们的文件的编码很有可能是ISO-8859-1
,虽然置信度不高。经过尝试,解码成功。我们修改后的脚本全文如下:
# data_convert_new.py
# -*- coding: utf-8
import struct, sys, glob, random, re
from nltk.tokenize import sent_tokenize, word_tokenize
from tensorflow.core.example import example_pb2
def get_abstract(html_file):
html = open(html_file, 'rb')
doc = html.read().decode("ISO-8859-1")
abstract = re.findall(
re.compile(r'<title>([^<]+)- CNN.com</title>'), doc)[0]
abstract = re.sub(r'[^\x00-\x7F]+', '*', abstract.replace('=', '*'))
return abstract.strip()
def para_tokenize(para):
sentences = []
para_format = "<p> {} </p>"
sent_format = "<s> {} </s>"
for sentence in sent_tokenize(para):
sentences.append(sent_format.format(' '.join(word_tokenize(sentence))))
return para_format.format(' '.join(sentences))
def raw2text(id, out_file):
doc_format = "<d> {} </d>"
dat_format = "abstract={}\tarticle={}\tpublisher=CNN\r\n"
text = open("raw_data/cnn/stories/" + id + ".story", 'r').readlines()
writer = open(out_file, 'a')
index = len(text)
for i, t in enumerate(text):
if t.startswith("@highlight"):
index = i
break
text = text[:index]
for i in range(index):
text[i] = re.sub(r'[^\x00-\x7F]+', '*', text[i].replace('=', '*'))
abstract = doc_format.format(para_tokenize(get_abstract("raw_data/cnn/stories/" + id + ".html").strip()))
article = doc_format.format(' '.join([para_tokenize(line.strip()) for line in text[1:] if len(line) > 2]))
writer.write(dat_format.format(abstract, article))
writer.close()
def text2bin(in_file, out_file):
inputs = open(in_file, 'r').readlines()
writer = open(out_file, 'wb')
for inp in inputs:
tf_example = example_pb2.Example()
for feature in inp.strip().split('\t'):
(k, v) = feature.split('=')
tf_example.features.feature[k].bytes_list.value.extend([v])
tf_example_str = tf_example.SerializeToString()
str_len = len(tf_example_str)
writer.write(struct.pack('q', str_len))
writer.write(struct.pack('%ds' % str_len, tf_example_str))
writer.close()
def main():
stories = glob.glob("raw_data/cnn/stories/*.story")
htmls = glob.glob("raw_data/cnn/stories/*.html")
dicts = [story.split('.')[0].split('/')[-1] for story in stories]
available_ids = []
for html in htmls:
id = html.split('.')[0].split('/')[-1]
if id in dicts:
available_ids.append(id)
del stories[:]
del htmls[:]
del dicts[:]
random.shuffle(available_ids)
i, lenth = 0, len(available_ids)
if lenth == 0:
print("No Match!")
else:
for id in available_ids:
raw2text(id, "story.final.txt")
i += 1
if i % 10 == 0:
print("%d / %d" % (i, lenth))
text2bin("story.final.txt", "story.final.bin")
if __name__ == '__main__':
main()
4 修改词频统计脚本
为了加快词频统计,我使用分批次统计最后合并的方法。这里实际上是可以用多线程的,但是我对python的多线程并不熟悉,所以就使用了“多进程”。首先,编写一个能够根据参数统计某段语料的词频的脚本:
import re, sys
def word_count(text, vocab):
for word in text.split():
word = re.sub(r'[^\x00-\x7F]+', '*', word)
if word in vocab.keys():
vocab[word] += 1
else:
vocab[word] = 1
def main():
vocab = {}
if len(sys.argv) < 4:
print("Lack of parameter!")
return
in_file = sys.argv[1]
start_line = int(sys.argv[2])
end_line = int(sys.argv[3])
out_file = in_file + "." + sys.argv[2] + "_" + sys.argv[3]
inputs = open(in_file, 'r').readlines()[start_line: end_line]
total = len(inputs)
i = 0
for line in inputs:
for feature in line.strip().split('\t')[:-1]:
word_count(feature.split('=')[1], vocab)
with open(out_file, "w") as writer:
for word, freq in vocab.items():
writer.write(word + " " + str(freq) + "\r\n")
i += 1
if i % 20 == 0:
print("%d out of %d done" % (i, total))
if __name__ == '__main__':
main()
在我kill掉原始词频统计程序的时候,根据log可以看出大概统计到了第34600行,那么我们使用以下命令进行试验:
$ python seg_count.py story.txt 34600 34700
可以看出100行的批次运行速度还是很快的,大概需要5秒。我们可以查看一下原文的文件的总行数。更多的查看行数操作请参见这篇博客:linux下统计文本行数的各种方法
$ sed -n "$=" story.txt
我的文本文档的长度是93488,考虑合理性,接下来我将使用1000行作为批次大小,每十个命令为一组进行运行。将原始的词频统计结果重名名为story.txt.merged
,那么可以看到程序运行后的文件列表大致如下:
story.txt
story.txt.34800_34900
story.txt.34800_35800
story.txt.35800_36800
story.txt.36800_37800
story.txt.37800_38800
story.txt.38800_39800
story.txt.39800_40800
story.txt.40800_41800
story.txt.41800_42800
story.txt.42800_43800
story.txt.merged
保证以story.txt开头的没有其他不相关的文件。下面开始编写合并程序:
import sys, os, glob
def merge_dict(file_name):
vocab = {
'<p>' : 0,
'</p>' : 0,
'<s>' : 0,
'</s>' : 0,
'<UNK>': 0,
'<PAD>': 0,
'<d>' : 0,
'</d>' : 0,
}
file_list = glob.glob(file_name + ".*")
merged = file_name + ".merged"
if merged in file_list:
file_list.remove(merged)
file_list.append(merged)
total = len(file_list)
i = 0
for file in file_list:
with open(file, 'r') as reader:
for line in reader.readlines():
item = line.strip().split()
if len(item) < 2:
continue
if item[0] in vocab.keys():
vocab[item[0]] += int(item[1])
else:
vocab[item[0]] = int(item[1])
i += 1
if i % 5 == 0:
print("Have Merged %d out of %d" % (i, total))
s_start = vocab.pop('<s>')
s_end = vocab.pop('</s>')
vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
vocab.append(('<s>', s_start))
vocab.append(('</s>', s_end))
with open(merged, 'w') as writer:
for word, freq in vocab:
writer.write(word + " " + str(freq) + "\r\n")
file_list.remove(merged)
for file in file_list:
os.remove(file)
def main():
if len(sys.argv) < 2:
print("Lack of parameter!")
return
merge_dict(sys.argv[1])
if __name__ == '__main__':
main()
程序中将file_name.*
格式的文件加入合并列表,由于file_name.merged
是上一轮merge后的结果,文件较大,所以人为地将它放到列表最后。合并所有文件后不要忘记,将原有的临时文件都删除掉,以免下次合并时又合并一遍而导致词频重计。
除此之外,我们还有新加入的标题数据没有统计,我将编写一个脚本专门统计标题中出现的词,将其保存为story.txt.title
,以待最终的合并。这时我含有标题的正确数据已经保存在了story.final.txt
。
import re, sys
def title_words(text):
title = re.search(r'abstract=([^\t]+)', text).group(1)
res = re.findall(re.compile(r'<s> ([^<]+)</s>'), title)
to_return = []
if not res:
print(text[:80])
else:
for sentence in res:
to_return.extend(sentence.strip().split())
return to_return
def main():
vocab = {}
file_name = sys.argv[1]
file = open(file_name, 'r').readlines()
i, total = 0, len(file)
for line in file:
for word in title_words(line):
if word in vocab.keys():
vocab[word] += 1
else:
vocab[word] = 1
i += 1
if i % 100 == 0:
print("%d out of %d done" % (i, total))
with open(file_name + ".title", 'w') as writer:
for word, freq in vocab.items():
writer.write(word + " " + str(freq) + "\r\n")
print("write done")
if __name__ == "__main__":
main()
这里有一点需要注意,标题也未必是一句话,我们只能肯定它是一段话。通过这个小失误,我意外地发现NLTK的分句功能会将人名,诸如T.H.Atlantis
中的.
当做英文句号而将一句话切分成两句。另外词表中也包含了大量的数字,吃一堑长一智,下次应该进行更细致的预处理,将人名中的.
替换掉,将数字统一起来。
这段代码运行的很快,由于我们的正则匹配的外层是search匹配而不是findall,而所有的标题都恰好在开头的地方,所以摘取title中的句子很快。而title又都很短,所以分词也很快。
5 再战标题数据下载
100MB的数据比起1.4GB的原始数据,损失量简直无法忍受,所以我决定再战。那么究竟是什么限制了网络连接呢。就算网速慢,只要连接稳定,我用服务器下上一天一夜也无所谓。实际上我在使用wget
的时候还出现了如下小插曲:
$ wget https://public.boxcloud.com/d/1/b1!8SakkX6R-6MBeWd8pD-33E3NTavnhqpOlgmETnioeQyX_codFS0DXWjRzjkEUWNfhBjH10c_C13EqDh4Hp5dEYX1usAfkYBkX4LCC9qwxionSTEIDEoNG5uaqpgkD0B0MJs2-IlaVOIWtaGIdVARiW8lizXxLJw_LSHlMZKGnwDy6BckLILbhRLA1l7EZYh7EYMq5VVTbt95rePcoUOXVUvtlHg9UbXSSjAEiwsvsmiJB_K3hcnvMNplaGpjfRgz0wrpDFBQMTnm_Rng6_giN4vOodxSQsEI1pNeKtXszgIjW0HFd00qhBdRI1eZC5k4AZ6qoQ-SjaqAWc9_lqthiFZk95YFlvQ0PLZJJO0mUBDnWrh-Zbbzc4XfXG4fqrviMxT4-ge8fOftKJ-M4Kd7f1Jn1nx3K21L_p5aHPzbvQHrRH17q-3eoGCahhO0ZwC5-3ecS6tTYd40VWa8RuglgAJFhMWWTnuaYlzgrBMnRY9yJQrG5E0wN6NiL-cTi2w846VrHXtczFdrqyqf9PKOVETnpRc3HjW0nC5-JGHdUiwIASgxkV31dpCIvxSRS4tiL0a39OIxx1v7ztPRtFKVfz-xNte0tNSRJgEs8-qwQEzqpjk0AJ0ahYsj0w8A_Wfx_wdprGQ5AMxRct_FpXCFc89rGc8qwW_y4H4RKBHzcrT03ZxS4NzbwczAuSVrJxDILF5mmbxdEfOoLUwKS_yfFd1DHOA6NHPiwjlb2inmBzhgnzadi34zpC1efSOxlIR_4RlPa14N4Lm6YxUQlRJ4_wzz0YkaiToskTK87HDGrO5SicfRa6btv26pHyImwQRXar5VYhusJiiPq8GWgG26l3EBxcQfSPS5Qv1eTkxVrK_tPSOm2zalCaAirAcYTXpsIhJzF0upgT1WL69N3hPs0Mz9qwVAQ2lC2bC88wZmMoxTzsokwasLWHdd9RQTL9gR6N-twNDUuEof2e122wIpFvQBC-LAiUMQH4NRRzd-KhtXSqC_0LolgBMVnJWVbgboDd9jA2VqhH8UZvEBXxqqO16OZAYlVOZ27PWp7ajFNX_ChSMTd5zlbIdFjZx8boNk7sD74i-bJxiBtPPbHFhvIFyXFVAWwVkHX_aoX6wGuvmVJJ1NnXW0LFdnrNdfWEyc7VsCD2KkT1coDHBoKbvU448BxtX-x5gu5YoPIPQIbyyJglcmSMVrjifaLqISvQEdOfY6uaEZxl39ADbY0OWwTHRD7oicT3Zhhqxx6o9vMb_wyPFZRo5S4XLFWjpxe1TA5suYKo7GmOcj0HuVGQ../download
-bash: !8: event not found
后来我发现!
属于特殊符号,需要在前面加上\
转义。那么一个纽约大学的网站为什么链接会如此的不干净呢?实际上这并不是原始的链接,而是挂载VPN后,建立的类似隧道的链接。所以VPN不稳定是决定下载链接不稳定的首要因素。
我猜想VPN的遵从一个分配原理,类似于操作系统的调度,它会给每个用户一定的稳定时间,然后就涉及到了切换,甚至是不同代理服务器之间的切换。而下载需要一个稳定的链接,如果你的IP都变了,那链接肯定会断。所以我觉得VPN不好,就不可能上外网下载一个很大的文件。要是数据集被拆分成10个压缩包就好了,那我一定能下下来。
最后的最后,我只能花4美元租用了一个专业版的VPN,下载速度一度突破1MB/s。我使用N-best的办法,最终取了最大的文件,一共1GB,包含74807条数据。
另外,很有意思的是,VPN的速度有一个规律,相同国家优先,其次服务器地理位置越靠近自己,网速越快。所以最快的是香港,然后是日本。但刚刚下载完那1GB的数据,香港方面就把我禁了。