自动文本摘要经典模型TextSum运行录(三):继续战斗

这篇文章是Textsum数据处理的续篇,主要记录了再次实验中遇到的问题,以及对实验的不断改进和完善的过程。

1 再次运行模型

由于词频统计脚本实在是太慢了,在它统计完三分之一,即约三万条数据的时候,我决定开始重新运行模型。我们将处理好的9w条CNN数据,取前34600条作为训练样本,生成新的文件。将其拷贝到Textsum工作空间下的data/下,重命名为data。同时将对应的story.vocab.new词表文件拷贝到同一目录下。
吸取之前实验的经验,这次要将最大epoch设置的小一些,使用以下命令。不要忘了将原来实验的日志文档和中间结果备份起来,否则会被覆盖掉。

$ nohup bazel-bin/textsum/seq2seq_attention \
    --mode=train \
    --article_key=article \
    --abstract_key=abstract \
    --data_path=data/data \
    --vocab_path=data/vocab \
    --log_root=textsum/log_root \
    --train_dir=textsum/log_root/train \
    --max_run_steps=10000 \
    > train_log 2>&1 &

但在此时我遇到了Bug:Memory Error

网上说这主要是因为Python的垃圾回收机制不够智能之类的,导致了内存不足。但实际上是因为我没有将文本文件转换为二进制文件,而直接输入给了模型,导致某个步骤出现了死循环。这里一定要细心,使用上篇文章讲过的方法生成data的二进制格式文件,重新运行。

这时又出现了Drop Example的情况:

WARNING:tensorflow:Drop an example - too long.

实际上,在模型的batch_reader.py文件中的Batcher._FillInputQueue()中有对数据文段长度的过滤:

# Filter out too-short input
if (len(enc_inputs) < self._hps.min_input_len or len(dec_inputs) < self._hps.min_input_len):
    tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d', len(enc_inputs), len(dec_inputs))
    continue

# If we're not truncating input, throw out too-long input
if not self._truncate_input:
    if (len(enc_inputs) > self._hps.enc_timesteps or len(dec_inputs) > self._hps.dec_timesteps):
        tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d', len(enc_inputs), len(dec_inputs))
        continue

可以看到这里对Abstract最大长度限制为hps.dec_timesteps,而这个值的定义在seq2seq_attention.py中:

hps = seq2seq_attention_model.HParams(
        mode=FLAGS.mode,  # train, eval, decode
        min_lr=0.01,  # min learning rate.
        lr=0.15,  # learning rate
        batch_size=batch_size,
        enc_layers=4,
        enc_timesteps=120,
        dec_timesteps=30,
        min_input_len=2,  # discard articles/summaries < than this
        num_hidden=256,  # for rnn cell
        emb_dim=128,  # If 0, don't use embedding
        max_grad_norm=2,
        num_softmax_samples=4096)  # If 0, no sampled softmax.

大致衡量一下,超出范围的abstract的长度大都为50以下,所以我们可以将dec_timesteps修改为50。在我们使用真正的新闻标题作为摘要后就基本不会出现这个问题。在修改了源码后一定不要忘记重新使用bazel进行编译。

$ bazel build -c opt textsum/...

另外我这里重新查了资料,如果加上–config=cuda参数之所以会出现以下问题:

INFO: Options provided by the client:
  Inherited 'common' options: --isatty=1 --terminal_columns=109
ERROR: Config value cuda is not defined in any .rc file
INFO: Invocation ID: ed698fdb-cd68-4831-9614-f311901731cf

是因为bazel的版本问题与tensorflow不兼容,当然我当时安装的tensorflow就不是GPU版的,也不可能对,哈哈。参考这篇博客:源码编译安装tensorflow 1.8
接下来开始运行模型,我又出现了以下的神奇BUG:

Caused by op u'save/Assign_28', defined at:
  File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 213, in <module>
    tf.app.run()
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 196, in main
    _Train(model, batcher)
  File "/home/liushangyu/models/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 86, in _Train
    saver = tf.train.Saver()
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1051, in __init__
    self.build()
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1081, in build
    restore_sequentially=self._restore_sequentially)
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 675, in build
    restore_sequentially, reshape)
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 414, in _AddRestoreOps
    assign_ops.append(saveable.restore(tensors, shapes))
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 155, in restore
    self.op.get_shape().is_fully_defined())
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/ops/gen_state_ops.py", line 47, in assign
    use_locking=use_locking, name=name)
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/liushangyu/anaconda2/envs/tensorflow3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [266993] rhs shape= [10003]
         [[Node: save/Assign_28 = Assign[T=DT_FLOAT, _class=["loc:@seq2seq/output_projection/v"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](seq2seq/output_projection/v, save/RestoreV2_28)]]

后来发现这是由于原来的log_root产生了干扰,在重新运行程序时,可能会将原有的断点加载进去,而原来的模型和现在的模型的参数维度不匹配,所以就会报错,只需要将原有的log_root下的文件都删除掉即可:

$ rm -rf textsum/log_root/*

后来发现这是由于原来的log_root产生了干扰,在重新运行程序时,可能会将原有的断点加载进去,而原来的模型和现在的模型的参数维度不匹配,所以就会报错,只需要将原有的log_root下的文件都删除掉即可:

$ rm -rf textsum/log_root/*

模型运行结束发现loss值只能最小减小到0.5左右,之前的toy data可以收敛到0.00001,可见模型并没有收敛,所以将--max_run_step设置为10000并不合理,下次实验将会尝试100000。

2 寻找文段标题

在以上的实验中也能看出,很多我们称之为abstract的第一段,实际很长,并不能作为摘要,这个假设是很失败的,所以我们需要寻找到文段的标题,作为真正的摘要。

进入DMQA的官网下载CNN stories的html源码。之前之所以没有直接这么做是因为这个网站是纽约大学的,这个文件总共1.4GB,要翻墙下载这么大的文件谈何容易。但现在还是需要尝试一下。

我使用了Google的插件Setup VPN,当然下载这个插件也需要翻Google商店的墙,具体怎么做不方便透露,可以私下联系我。最初下载的时候,链接很不稳定,网速大概只有几十KB每秒,但更大的困难在于链接会断,最后只下载下来100MB左右的文件。好在压缩包是多文件打包的,所以我把压缩包传到服务器上,可以查看文件的内容。

查看CNN stories html的网页样例:Russian bomber buzzes U.S. aircraft carrier

可以发现这就是普通的新闻网站啊,从源码中寻找标题,有两处是匹配的:

<title>Russian bomber buzzes U.S. aircraft carrier - CNN.com</title>
<H1>  Russian bomber buzzes U.S. aircraft carrier</H1>

这里你选择哪一种都可以,它们的正则表达式匹配方式对应如下:

r'<title>([^<]+)- CNN.com</title>'
r'<H1>([\s\w\.-:]+)</H1>'

很显然使用第一种方式更保险一些,之所以第二种没有用[^<]+的方式,是因为<H1></H1>标签之间有可能有其他控制格式的标签,但我们又很难穷举标题中可能含有的字符,所以还是使用第一种匹配方式更加稳妥。

3 修改数据格式转换脚本

我们可以修改原始脚本而不需编写从html文档解析文段的脚本,这是有一个前提的,就是我们能够将标题与已有的文段数据对应起来。很幸运,我们观察story的文件名和html的文件名,相同文段的文件名中的ID是重合的,例如:

101309778895e6ca14fca581352735a07b9fc6f7.story
101309778895e6ca14fca581352735a07b9fc6f7.html

那么我们需要将两份数据进行交叉对比,选出能够成对的文件来,并将原来的脚本(参见上一篇博客)中的abstract的引用修改成从html文件获取。但这里也不是一帆风顺的,我们遇到了类似于下面的经典BUG:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa0 in position 41213: invalid start byte

utf-8utf-16asciigbk之类的都不能解码。这时就用到了chardet这个Python工具包。它可以在很大的编码列表中尝试解码指定文件,并通过统计数据给出对编码的猜测以及相应的置信度。我们可以向下面这样使用它:

def check_encoding(file_name):
    file = open(file_name, 'rb')
    info = chardet.detect(file.read())
    return info["encoding"]

info变量的格式如下:

{'confidence': 0.73, 'language': '', 'encoding': 'ISO-8859-1'}

可见我们的文件的编码很有可能是ISO-8859-1,虽然置信度不高。经过尝试,解码成功。我们修改后的脚本全文如下:

# data_convert_new.py
# -*- coding: utf-8
import struct, sys, glob, random, re
from nltk.tokenize import sent_tokenize, word_tokenize
from tensorflow.core.example import example_pb2

def get_abstract(html_file):
    html = open(html_file, 'rb')
    doc = html.read().decode("ISO-8859-1")
    abstract = re.findall(
        re.compile(r'<title>([^<]+)- CNN.com</title>'), doc)[0]
    abstract = re.sub(r'[^\x00-\x7F]+', '*', abstract.replace('=', '*'))
    return abstract.strip()

def para_tokenize(para):
    sentences = []
    para_format = "<p> {} </p>"
    sent_format = "<s> {} </s>"
    for sentence in sent_tokenize(para):
        sentences.append(sent_format.format(' '.join(word_tokenize(sentence))))
    return para_format.format(' '.join(sentences))

def raw2text(id, out_file):
    doc_format = "<d> {} </d>"
    dat_format = "abstract={}\tarticle={}\tpublisher=CNN\r\n"
    text = open("raw_data/cnn/stories/" + id + ".story", 'r').readlines()
    writer = open(out_file, 'a')
    index = len(text)
    for i, t in enumerate(text):
        if t.startswith("@highlight"):
            index = i
            break
    text = text[:index]
    for i in range(index):
        text[i] = re.sub(r'[^\x00-\x7F]+', '*', text[i].replace('=', '*'))
    abstract = doc_format.format(para_tokenize(get_abstract("raw_data/cnn/stories/" + id + ".html").strip()))
    article = doc_format.format(' '.join([para_tokenize(line.strip()) for line in text[1:] if len(line) > 2]))
    writer.write(dat_format.format(abstract, article))
    writer.close()

def text2bin(in_file, out_file):
    inputs = open(in_file, 'r').readlines()
    writer = open(out_file, 'wb')
    for inp in inputs:
        tf_example = example_pb2.Example()
        for feature in inp.strip().split('\t'):
            (k, v) = feature.split('=')
            tf_example.features.feature[k].bytes_list.value.extend([v])
        tf_example_str = tf_example.SerializeToString()
        str_len = len(tf_example_str)
        writer.write(struct.pack('q', str_len))
        writer.write(struct.pack('%ds' % str_len, tf_example_str))
    writer.close()

def main():
    stories = glob.glob("raw_data/cnn/stories/*.story")
    htmls = glob.glob("raw_data/cnn/stories/*.html")
    dicts = [story.split('.')[0].split('/')[-1] for story in stories]
    available_ids = []
    for html in htmls:
        id = html.split('.')[0].split('/')[-1]
        if id in dicts:
            available_ids.append(id)
    del stories[:]
    del htmls[:]
    del dicts[:]
    random.shuffle(available_ids)
    i, lenth = 0, len(available_ids)
    if lenth == 0:
        print("No Match!")
    else:    
        for id in available_ids:
            raw2text(id, "story.final.txt")
            i += 1
            if i % 10 == 0:
                print("%d / %d" % (i, lenth))
        text2bin("story.final.txt", "story.final.bin")

if __name__ == '__main__':
    main()

4 修改词频统计脚本

为了加快词频统计,我使用分批次统计最后合并的方法。这里实际上是可以用多线程的,但是我对python的多线程并不熟悉,所以就使用了“多进程”。首先,编写一个能够根据参数统计某段语料的词频的脚本:

import re, sys
def word_count(text, vocab):
    for word in text.split():
        word = re.sub(r'[^\x00-\x7F]+', '*', word)
        if word in vocab.keys():
            vocab[word] += 1
        else:
            vocab[word] = 1

def main():
    vocab = {}
    if len(sys.argv) < 4:
        print("Lack of parameter!")
        return
    in_file = sys.argv[1]
    start_line = int(sys.argv[2])
    end_line = int(sys.argv[3])
    out_file = in_file + "." + sys.argv[2] + "_" + sys.argv[3]
    inputs = open(in_file, 'r').readlines()[start_line: end_line]
    total = len(inputs)
    i = 0
    for line in inputs:
        for feature in line.strip().split('\t')[:-1]:
            word_count(feature.split('=')[1], vocab)
        with open(out_file, "w") as writer:
            for word, freq in vocab.items():
                writer.write(word + " " + str(freq) + "\r\n")
        i += 1
        if i % 20 == 0:
            print("%d out of %d done" % (i, total))
if __name__ == '__main__':
    main()

在我kill掉原始词频统计程序的时候,根据log可以看出大概统计到了第34600行,那么我们使用以下命令进行试验:

$ python seg_count.py story.txt 34600 34700

可以看出100行的批次运行速度还是很快的,大概需要5秒。我们可以查看一下原文的文件的总行数。更多的查看行数操作请参见这篇博客:linux下统计文本行数的各种方法

$ sed -n "$=" story.txt

我的文本文档的长度是93488,考虑合理性,接下来我将使用1000行作为批次大小,每十个命令为一组进行运行。将原始的词频统计结果重名名为story.txt.merged,那么可以看到程序运行后的文件列表大致如下:

story.txt
story.txt.34800_34900
story.txt.34800_35800
story.txt.35800_36800
story.txt.36800_37800
story.txt.37800_38800
story.txt.38800_39800
story.txt.39800_40800
story.txt.40800_41800
story.txt.41800_42800
story.txt.42800_43800
story.txt.merged

保证以story.txt开头的没有其他不相关的文件。下面开始编写合并程序:

import sys, os, glob
def merge_dict(file_name):
    vocab = {
        '<p>'  : 0,
        '</p>' : 0,
        '<s>'  : 0,
        '</s>' : 0,
        '<UNK>': 0,
        '<PAD>': 0,
        '<d>'  : 0,
        '</d>' : 0,
    }
    file_list = glob.glob(file_name + ".*")
    merged = file_name + ".merged"
    if merged in file_list:
        file_list.remove(merged)
        file_list.append(merged)
    total = len(file_list)
    i = 0
    for file in file_list:
        with open(file, 'r') as reader:
            for line in reader.readlines():
                item = line.strip().split()
                if len(item) < 2:
                    continue
                if item[0] in vocab.keys():
                    vocab[item[0]] += int(item[1])
                else:
                    vocab[item[0]] = int(item[1])
        i += 1
        if i % 5 == 0:
            print("Have Merged %d out of %d" % (i, total))
    s_start = vocab.pop('<s>')
    s_end = vocab.pop('</s>')
    vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
    vocab.append(('<s>', s_start))
    vocab.append(('</s>', s_end))
    with open(merged, 'w') as writer:
        for word, freq in vocab:
            writer.write(word + " " + str(freq) + "\r\n")
    file_list.remove(merged)
    for file in file_list:
        os.remove(file)

def main():
    if len(sys.argv) < 2:
        print("Lack of parameter!")
        return
    merge_dict(sys.argv[1])

if __name__ == '__main__':
    main()

程序中将file_name.*格式的文件加入合并列表,由于file_name.merged是上一轮merge后的结果,文件较大,所以人为地将它放到列表最后。合并所有文件后不要忘记,将原有的临时文件都删除掉,以免下次合并时又合并一遍而导致词频重计。

除此之外,我们还有新加入的标题数据没有统计,我将编写一个脚本专门统计标题中出现的词,将其保存为story.txt.title,以待最终的合并。这时我含有标题的正确数据已经保存在了story.final.txt

import re, sys
def title_words(text):
    title = re.search(r'abstract=([^\t]+)', text).group(1)
    res = re.findall(re.compile(r'<s> ([^<]+)</s>'), title)
    to_return = []
    if not res:
        print(text[:80])
    else:
        for sentence in res:
            to_return.extend(sentence.strip().split())
    return to_return

def main():
    vocab = {}
    file_name = sys.argv[1]
    file = open(file_name, 'r').readlines()
    i, total = 0, len(file)
    for line in file:
        for word in title_words(line):
            if word in vocab.keys():
                vocab[word] += 1
            else:
                vocab[word] = 1
        i += 1
        if i % 100 == 0:
            print("%d out of %d done" % (i, total))
    with open(file_name + ".title", 'w') as writer:
        for word, freq in vocab.items():
            writer.write(word + " " + str(freq) + "\r\n")
    print("write done")

if __name__ == "__main__":
    main()

这里有一点需要注意,标题也未必是一句话,我们只能肯定它是一段话。通过这个小失误,我意外地发现NLTK的分句功能会将人名,诸如T.H.Atlantis中的.当做英文句号而将一句话切分成两句。另外词表中也包含了大量的数字,吃一堑长一智,下次应该进行更细致的预处理,将人名中的.替换掉,将数字统一起来。

这段代码运行的很快,由于我们的正则匹配的外层是search匹配而不是findall,而所有的标题都恰好在开头的地方,所以摘取title中的句子很快。而title又都很短,所以分词也很快。

5 再战标题数据下载

100MB的数据比起1.4GB的原始数据,损失量简直无法忍受,所以我决定再战。那么究竟是什么限制了网络连接呢。就算网速慢,只要连接稳定,我用服务器下上一天一夜也无所谓。实际上我在使用wget的时候还出现了如下小插曲:

$ wget https://public.boxcloud.com/d/1/b1!8SakkX6R-6MBeWd8pD-33E3NTavnhqpOlgmETnioeQyX_codFS0DXWjRzjkEUWNfhBjH10c_C13EqDh4Hp5dEYX1usAfkYBkX4LCC9qwxionSTEIDEoNG5uaqpgkD0B0MJs2-IlaVOIWtaGIdVARiW8lizXxLJw_LSHlMZKGnwDy6BckLILbhRLA1l7EZYh7EYMq5VVTbt95rePcoUOXVUvtlHg9UbXSSjAEiwsvsmiJB_K3hcnvMNplaGpjfRgz0wrpDFBQMTnm_Rng6_giN4vOodxSQsEI1pNeKtXszgIjW0HFd00qhBdRI1eZC5k4AZ6qoQ-SjaqAWc9_lqthiFZk95YFlvQ0PLZJJO0mUBDnWrh-Zbbzc4XfXG4fqrviMxT4-ge8fOftKJ-M4Kd7f1Jn1nx3K21L_p5aHPzbvQHrRH17q-3eoGCahhO0ZwC5-3ecS6tTYd40VWa8RuglgAJFhMWWTnuaYlzgrBMnRY9yJQrG5E0wN6NiL-cTi2w846VrHXtczFdrqyqf9PKOVETnpRc3HjW0nC5-JGHdUiwIASgxkV31dpCIvxSRS4tiL0a39OIxx1v7ztPRtFKVfz-xNte0tNSRJgEs8-qwQEzqpjk0AJ0ahYsj0w8A_Wfx_wdprGQ5AMxRct_FpXCFc89rGc8qwW_y4H4RKBHzcrT03ZxS4NzbwczAuSVrJxDILF5mmbxdEfOoLUwKS_yfFd1DHOA6NHPiwjlb2inmBzhgnzadi34zpC1efSOxlIR_4RlPa14N4Lm6YxUQlRJ4_wzz0YkaiToskTK87HDGrO5SicfRa6btv26pHyImwQRXar5VYhusJiiPq8GWgG26l3EBxcQfSPS5Qv1eTkxVrK_tPSOm2zalCaAirAcYTXpsIhJzF0upgT1WL69N3hPs0Mz9qwVAQ2lC2bC88wZmMoxTzsokwasLWHdd9RQTL9gR6N-twNDUuEof2e122wIpFvQBC-LAiUMQH4NRRzd-KhtXSqC_0LolgBMVnJWVbgboDd9jA2VqhH8UZvEBXxqqO16OZAYlVOZ27PWp7ajFNX_ChSMTd5zlbIdFjZx8boNk7sD74i-bJxiBtPPbHFhvIFyXFVAWwVkHX_aoX6wGuvmVJJ1NnXW0LFdnrNdfWEyc7VsCD2KkT1coDHBoKbvU448BxtX-x5gu5YoPIPQIbyyJglcmSMVrjifaLqISvQEdOfY6uaEZxl39ADbY0OWwTHRD7oicT3Zhhqxx6o9vMb_wyPFZRo5S4XLFWjpxe1TA5suYKo7GmOcj0HuVGQ../download

-bash: !8: event not found

后来我发现!属于特殊符号,需要在前面加上\转义。那么一个纽约大学的网站为什么链接会如此的不干净呢?实际上这并不是原始的链接,而是挂载VPN后,建立的类似隧道的链接。所以VPN不稳定是决定下载链接不稳定的首要因素。

我猜想VPN的遵从一个分配原理,类似于操作系统的调度,它会给每个用户一定的稳定时间,然后就涉及到了切换,甚至是不同代理服务器之间的切换。而下载需要一个稳定的链接,如果你的IP都变了,那链接肯定会断。所以我觉得VPN不好,就不可能上外网下载一个很大的文件。要是数据集被拆分成10个压缩包就好了,那我一定能下下来。

最后的最后,我只能花4美元租用了一个专业版的VPN,下载速度一度突破1MB/s。我使用N-best的办法,最终取了最大的文件,一共1GB,包含74807条数据。

另外,很有意思的是,VPN的速度有一个规律,相同国家优先,其次服务器地理位置越靠近自己,网速越快。所以最快的是香港,然后是日本。但刚刚下载完那1GB的数据,香港方面就把我禁了。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值