【动手学习pytorch笔记】37.4 BERT微调数据集

BERT微调数据集

自然语言推断任务:

主要研究 假设(hypothesis)是否可以从前提(premise)中推断出来, 其中两者都是文本序列。 换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:

  • 蕴涵(entailment):假设可以从前提中推断出来。
  • 矛盾(contradiction):假设的否定可以从前提中推断出来。
  • 中性(neutral):所有其他情况。

斯坦福自然语言推断(SNLI)数据集

由500000多个带标签的英语句子对组成的集合

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

#@save

d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = "D:\environment\data\data\snli_1.0"

读取数据集

#@save
def read_snli(data_dir, is_train):
    """将SNLI数据集解析为前提、假设和标签"""
    def extract_text(s):
        # 删除我们不会使用的信息
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # 用一个空格替换两个或多个连续的空格
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, encoding = 'utf-8') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] \
                in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('前提:', x0)
    print('假设:', x1)
    print('标签:', y)
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2
前提: A person on a horse jumps over a broken down airplane .
假设: A person is at a diner , ordering an omelette .
标签: 1
前提: A person on a horse jumps over a broken down airplane .
假设: A person is outdoors , on a horse .
标签: 0

统计三个关系的数量

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

上面训练集,下面测试集,挺平均的

data[2]是标签,统计标签数量就行

#@save
class SNLIDataset(torch.utils.data.Dataset):
    """用于加载SNLI数据集的自定义数据集"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + \
                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
                         for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)

vocab需要与BERT预训练时的vocab保持一致,不然他不认识呀,所以下载与训练模型的时候一般都是下载模型和vocab

整理一下

#@save

def load_data_snli(batch_size, num_steps=50):
    """下载SNLI数据集并返回数据迭代器和词表"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = "D:\environment\data\data\snli_1.0"
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False)
    return train_iter, test_iter, train_set.vocab

看看大小

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples

18678
for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

batchsize = 128

一个句子长度为50

最后说一下这一节的踩坑:

  • 和之前一节的数据加载一样

    torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers = num_workers )

    num_workers = num_workers,开多线程读取数据会报错

  • 下载解压数据集时报错

    OSError: [Errno 22] Invalid argument: '..\\data\\snli_1.0\\Icon\r
    

    报错的原因:SNLI数据集的压缩文件"snli_1.0.zip"里面有两个路径为“snli_1.0\Icon\r”和“’__MACOSX/snli_1.0/._Icon\r’”的文件,导致无法解析此路径进而导致整个文件无法解压。

    解决方法:手动解压之后把data_dir赋值为数据集解压后的路径

    data_dir = d2l.download_extract('SNLI')
    

    改成

    data_dir = "D:\environment\data\data\snli_1.0"
    
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Faster R-CNN是一种常用的目标检测算法,它结合了区域提取网络(Region Proposal Network,RPN)和分类网络来实现物体检测。在使用Faster R-CNN进行目标检测时,通常需要将模型的源码进行微调,以适应自己的数据集。 在PyTorch中,微调Faster R-CNN的源码需要以下几个步骤: 1. 数据集准备:首先需要准备自己的目标检测数据集。该数据集需要包含图片和对应的标签信息,标签信息通常包括物体的类别和边界框坐标。可以使用标注工具如LabelImg等进行标注,并将标注结果保存为一种格式,如VOC格式。 2. 获取源码:从PyTorch官方的GitHub仓库中获取Faster R-CNN源码。可以使用git命令行或者直接在浏览器上下载源码的压缩包。 3. 修改数据集加载:在源码中找到数据集加载部分的代码。可以通过修改已有的数据集类或者新建一个数据集类来加载自己的数据集。在数据集类中,需要定义数据集的路径、读取图片和标签的方法等。 4. 修改训练设置:在源码中找到训练设置部分的代码。根据自己的需求修改训练的batch size、学习率、训练轮数等参数。可以根据实际情况调整这些参数,以获得更好的训练效果。 5. 开始微调:在终端中切换到源码所在的目录,并执行训练指令,如"python train.py"。这将开始使用自己的数据集对Faster R-CNN进行微调。在微调过程中,可以观察训练日志和损失曲线,以评估训练的效果。 6. 模型保存:微调完成后,可以将训练得到的模型保存下来,以便后续的测试和推理使用。可以将模型保存为一个.pth文件,以便后续加载和使用。 通过以上步骤,我们可以使用PyTorch实现对Faster R-CNN的源码进行微调,以适应自己的目标检测数据集微调后的模型可以用于检测目标物体,并根据实际需要进行后续处理和应用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值