下载处理数据主要是抽象用于训练seq2seq模型的数据集。
该类将处理位于指定文件夹中的文档。
预处理可以处理任何格式合理的文档。在CNN/DailyMail数据集中,它将提取故事和摘要。CNN /每日新闻:CNN/Daily News原始数据集从[1]下载。
故事存储在不同的文件中;摘要以句子的形式出现在故事的结尾,并以特殊的“@highlight”行作为前缀。
要处理数据,解压缩同一文件夹中的两个数据集,并将该文件夹的路径作为"data_dir参数传递。
格式化代码的灵感来自[2]。
首先查看数据集
数据集为新闻故事加摘要
摘要前有@highlight作为前缀
定义一个类 来处理这些数据
class CNNDataset(Dataset):
def __init__(self, ):
path = 'F:\cwb5\progressive-generation-master\download\cnn1\cnn\stories'
if not os.path.exists(path):
os.system('perl download/gdown.pl '
'\'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ\' '
'cnn_stories.tgz')
os.system('tar -xvf cnn_stories.tgz')
assert os.path.isdir(path)
self.documents = []
story_filenames_list = sorted(os.listdir(path))
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue
self.documents.append(path_to_story)
def __len__(self):
""" Returns the number of documents. """
return len(self.documents)
def __getitem__(self, idx):
document_path = self.documents[idx]
document_name = document_path.split("/")[-1]
with open(document_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return document_name, story_lines, summary_lines
如上为类的处理
分块来读
第一块为init初试化函数,在python中__xxx为私有变量,私有变量只可以从内部访问,外部不可以访问,__xxxx__为特殊变量,外部可以直接访问,而_xxx为一种约定的规矩,代表虽然可以直接访问但是最好把我视为私有变量不要随意访问。
def __init__(self, ):
path = 'F:\cwb5\progressive-generation-master\download\cnn1\cnn\stories'
if not os.path.exists(path):
os.system('perl download/gdown.pl '
'\'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ\' '
'cnn_stories.tgz')
os.system('tar -xvf cnn_stories.tgz')
assert os.path.isdir(path)
此处的path是本地下载数据集的路径,填写后使用os.path.exists进行一个检查,如果路径不存在,使用so.system,使用perl命令执行 download/gdown.pl脚本下载下方链接对应的文件,下载之后使用tar -xvf 对cnn_stories.tgz进行一个解压,至此云端的文件被下载到本地。
assert os.path.isdir(path)
使用assert来判定路径下有没有文件,os.path.isdir会返回路径下的目录,如果为空则会报错
self.documents = []
story_filenames_list = sorted(os.listdir(path))
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue
self.documents.append(path_to_story)
接上述初始化模块,self.documents=[],self只在类中出现,代表的是类本身,这里为CNNDataset添加一个属性变量documents,类型为[]列表。
story_filenames_list = sorted(os.listdir(path))
定义一个新闻故事名字列表,使用sorted函数对path的文件名字进行一个排序
sorted(iterable, key=None, reverse=False)
iterable : 可迭代对象
key : 排序函数,在sorted内部将可迭代对象中的每个元素传递给这个函数的参数,根据函数运算结果进行排序
reverse : 排序规则,reverse = True降序,reverse = False 升序(默认)
os.listdir()方法返回指定文件夹包含的文件或文件夹的名字的列表
path_to_story = os.path.join(path, story_filename)
建立一个路径列表path_to_story
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
使用for 遍历文件名字使用os.path.join(path, story_filename)合并路径
将数据集的所有路径存储在路径列表内
而后使用if与os.path.isfile判断路径是否为文件
self.documents.append(path_to_story)
使用append将path_to_story添加到CNNDatast的属性类documents中去
至此类中的第一个初始化函数完成
第二个长度函数
len返回documents 的个数
def __len__(self):
""" Returns the number of documents. """
return len(self.documents)
def __getitem__(self, idx):
document_path = self.documents[idx]
document_name = document_path.split("/")[-1]
with open(document_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return document_name, story_lines, summary_lines
documents_path为documents的带索引文件,documents_name将路径使用split分割
[-1]为最后一个,即为名称
with open(document_path, encoding="utf-8") as source:
raw_story = source.read()
将文件打开,使raw_story为读取的内容,将raw_story传入 函数process_story中,并将返回值赋予story_lines,summary_lines
返回文档名称,故事行,总结行
初始化函数结束
———————————————————————————————————————————
例:F:/wenjian/mingcheng
分割后['F','wenjian','mingcheng']
取[-1]后为 mingcheng
———————————————————————————————————————————
--------------------------------------------------------------------------------------------------------------------------------
os.path.join用法
import os
Path1 = 'home'
Path2 = 'develop'
Path20 = os.path.join(Path1, Path2)
print('Path20 = ',Path20)
输出:
Path20 = home\develop\
---------------------------------------------------------------------------------------------------------------------------------