progressive-generation-master代码记录【下载处理数据】(定义CNNDataset类)

下载处理数据主要是抽象用于训练seq2seq模型的数据集。

该类将处理位于指定文件夹中的文档。

预处理可以处理任何格式合理的文档。在CNN/DailyMail数据集中,它将提取故事和摘要。CNN /每日新闻:CNN/Daily News原始数据集从[1]下载。

故事存储在不同的文件中;摘要以句子的形式出现在故事的结尾,并以特殊的“@highlight”行作为前缀。

要处理数据,解压缩同一文件夹中的两个数据集,并将该文件夹的路径作为"data_dir参数传递。

格式化代码的灵感来自[2]。

[1] https://cs.nuu.edu/~kcho/

[2] GitHub - abisee/cnn-dailymail: Code to obtain the CNN / Daily Mail dataset (non-anonymized) for summarization

首先查看数据集

数据集为新闻故事加摘要

摘要前有@highlight作为前缀

定义一个类 来处理这些数据

class CNNDataset(Dataset):

    def __init__(self, ):
            
        path = 'F:\cwb5\progressive-generation-master\download\cnn1\cnn\stories'
        if not os.path.exists(path):
            os.system('perl download/gdown.pl '
                      '\'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ\' '
                      'cnn_stories.tgz')
            os.system('tar -xvf cnn_stories.tgz')

        assert os.path.isdir(path)

        self.documents = []
        story_filenames_list = sorted(os.listdir(path))
        for story_filename in story_filenames_list:
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

    def __getitem__(self, idx):
        document_path = self.documents[idx]
        document_name = document_path.split("/")[-1]
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
            story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines

如上为类的处理

分块来读

第一块为init初试化函数,在python中__xxx为私有变量,私有变量只可以从内部访问,外部不可以访问,__xxxx__为特殊变量,外部可以直接访问,而_xxx为一种约定的规矩,代表虽然可以直接访问但是最好把我视为私有变量不要随意访问。

 def __init__(self, ):
            
        path = 'F:\cwb5\progressive-generation-master\download\cnn1\cnn\stories'
        if not os.path.exists(path):
            os.system('perl download/gdown.pl '
                      '\'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ\' '
                      'cnn_stories.tgz')
            os.system('tar -xvf cnn_stories.tgz')
        assert os.path.isdir(path)

此处的path是本地下载数据集的路径,填写后使用os.path.exists进行一个检查,如果路径不存在,使用so.system,使用perl命令执行 download/gdown.pl脚本下载下方链接对应的文件,下载之后使用tar -xvf 对cnn_stories.tgz进行一个解压,至此云端的文件被下载到本地。

assert os.path.isdir(path)

使用assert来判定路径下有没有文件,os.path.isdir会返回路径下的目录,如果为空则会报错

self.documents = []
        story_filenames_list = sorted(os.listdir(path))
        for story_filename in story_filenames_list:
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

接上述初始化模块,self.documents=[],self只在类中出现,代表的是类本身,这里为CNNDataset添加一个属性变量documents,类型为[]列表。

story_filenames_list = sorted(os.listdir(path))

定义一个新闻故事名字列表,使用sorted函数对path的文件名字进行一个排序

sorted(iterable, key=None, reverse=False)
iterable : 可迭代对象
key : 排序函数,在sorted内部将可迭代对象中的每个元素传递给这个函数的参数,根据函数运算结果进行排序
reverse : 排序规则,reverse = True降序,reverse = False 升序(默认)

os.listdir()方法返回指定文件夹包含的文件或文件夹的名字的列表

path_to_story = os.path.join(path, story_filename)

建立一个路径列表path_to_story

for story_filename in story_filenames_list:

if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)

使用for 遍历文件名字使用os.path.join(path, story_filename)合并路径

将数据集的所有路径存储在路径列表内

而后使用if与os.path.isfile判断路径是否为文件

self.documents.append(path_to_story)

使用append将path_to_story添加到CNNDatast的属性类documents中去

至此类中的第一个初始化函数完成

第二个长度函数

len返回documents 的个数

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

def __getitem__(self, idx):
        document_path = self.documents[idx]
        document_name = document_path.split("/")[-1]
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
            story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines

documents_path为documents的带索引文件,documents_name将路径使用split分割

[-1]为最后一个,即为名称

with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()

将文件打开,使raw_story为读取的内容,将raw_story传入 函数process_story中,并将返回值赋予story_lines,summary_lines

返回文档名称,故事行,总结行

初始化函数结束

———————————————————————————————————————————

例:F:/wenjian/mingcheng

分割后['F','wenjian','mingcheng']

取[-1]后为 mingcheng

———————————————————————————————————————————

--------------------------------------------------------------------------------------------------------------------------------

os.path.join用法

import os
 
Path1 = 'home'
Path2 = 'develop'


Path20 = os.path.join(Path1, Path2)

print('Path20 = ',Path20)

输出:

Path20 = home\develop\

---------------------------------------------------------------------------------------------------------------------------------

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值