Here is the code:
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""CNN/DailyMail Summarization dataset, non-anonymized version."""
import hashlib
import json
import logging
import os

logger = logging.getLogger(__name__)
DM_SINGLE_CLOSE_QUOTE = "\u2019"  # unicode right single quotation mark
DM_DOUBLE_CLOSE_QUOTE = "\u201d"  # unicode right double quotation mark
# acceptable ways to end a sentence
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', DM_SINGLE_CLOSE_QUOTE, DM_DOUBLE_CLOSE_QUOTE, ")"]
def _read_text_file_path(path):
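    """Read a text file and return its lines with surrounding whitespace stripped."""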
with open(path, "r", encoding="utf-8") as f:
lines = [line.strip() for line in f]
return lines
def _get_url_hashes(path):
"""Get hashes of urls in file."""
urls = _read_text_file_path(path)
    def url_hash(u):
        h = hashlib.sha1()
        try:
            # str.encode raises UnicodeEncodeError (not UnicodeDecodeError) in Python 3.
            u = u.encode("utf-8")
        except UnicodeEncodeError:
            logger.error("Cannot hash url: %s", u)
            return None
        h.update(u)
        return h.hexdigest()
    return {url_hash(u) for u in urls}
def _get_hash_from_path(p):
"""Extract hash from path."""
return os.path.splitext(os.path.basename(p))[0]
def _get_art_abs(story_file, tfds_version):
"""Get abstract (highlights) and article from a story file path."""
# Based on https://github.com/abisee/cnn-dailymail/blob/master/
# make_datafiles.py
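    # A raw .story file interleaves the article text with "@highlight" markers:
    #
    #     <article sentence>
    #     <article sentence>
    #
    #     @highlight
    #
    #     <first summary sentence>
    #
    #     @highlight
    #
    #     <second summary sentence>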
    lines = _read_text_file_path(story_file)
    # The GitHub code lowercases the text; that step was removed in 3.0.0.
# Put periods on the ends of lines that are missing them
# (this is a problem in the dataset because many image captions don't end in
# periods; consequently they end up in the body of the article as run-on
# sentences)
def fix_missing_period(line):
"""Adds a period to a line that is missing a period."""
if "@highlight" in line:
return line
if not line:
return line
if line[-1] in END_TOKENS:
return line
return line + " ."
lines = [fix_missing_period(line) for line in lines]
# Separate out article and abstract sentences
article_lines = []
highlights = []
next_is_highlight = False
for line in lines:
if not line:
continue # empty line
elif line.startswith("@highlight"):
next_is_highlight = True
elif next_is_highlight:
highlights.append(line)
else:
article_lines.append(line)
# Make article into a single string
article = " ".join(article_lines)
if tfds_version >= "2.0.0":
abstract = "\n".join(highlights)
else:
abstract = " ".join(highlights)
return article, abstract
def _generate_examples(urls_file, stories_dir, config_version):
    """Yield (idx, example) pairs for stories whose url hash appears in urls_file."""
    urls = _get_url_hashes(urls_file)
    idx = 0
    for file in os.listdir(stories_dir):
        story_path = os.path.join(stories_dir, file)
        hash_from_path = _get_hash_from_path(story_path)
        if hash_from_path in urls:
            article, highlights = _get_art_abs(story_path, config_version)
            if not article or not highlights:
                continue
            yield idx, {
                "instruction": "Please help me summarize this article.",
                "input": article,
                "output": highlights,
            }
            idx += 1
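# Each yielded item is an (idx, example) pair, e.g.
#   (0, {"instruction": "Please help me summarize this article.",
#        "input": "<article>", "output": "<highlights>"})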
def main():
    json_file_path = "./result_train_cnn.json"
    train = []
    generator_cnn = _generate_examples("./url_lists/cnn_wayback_training_urls.txt", "./cnn/stories/", "3.0.0")
    # Collect just the example dicts; the idx is only a running counter.
    for _, example in generator_cnn:
        train.append(example)
    # The original snippet had placeholder arguments (***) here; the DailyMail
    # paths below are an assumption that mirrors the CNN layout above.
    generator_dm = _generate_examples("./url_lists/dailymail_wayback_training_urls.txt", "./dailymail/stories/", "3.0.0")
    for _, example in generator_dm:
        train.append(example)
    with open(json_file_path, mode="w", encoding="utf-8") as f:
        json.dump(train, f, ensure_ascii=False)
if __name__ == '__main__':
main()
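
For reference (a rough sketch, with the article and highlights abridged), each record in result_train_cnn.json should look like:

{"instruction": "Please help me summarize this article.", "input": "(CNN) -- The full article text ...", "output": "First highlight.\nSecond highlight."}

Because the config version is "3.0.0" (>= "2.0.0"), the highlights are joined with newlines rather than spaces.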