bert模型数据集加载方式

bert-base-chinese的分词
import pandas as pd
from transformers import BertTokenizerFast
import torch

model_name = "/data/transformer/classification_demo/bert-base-chinese"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
texts =[]
for i in range(2):
    texts.append("这是一段中文文本")
    
# 对文本进行编码
train_encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
train_encodings

train_encodings 结果:

{
'input_ids': [[101, 6821, 3221, 671, 3667, 704, 3152, 3152, 3315, 102], [101, 6821, 3221, 671, 3667, 704, 3152, 3152, 3315, 102]], 
'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
}

注意:model_name 这个目录里可以不需要模型文件和配置文件,因为这里只是分词,即只需要创建一个目录,包含v词表文件vocab.txt即可。

数据集构造

现记录一下PyTorch 的 torch.utils.data.Dataset 类的子类。Dataset 类是PyTorch框架中用于处理数据的基本组件,它允许用户定义自己的数据集类,以满足特定任务的需求。

Dataset是一个抽象基类,用于创建自定义数据集。它定义了两个核心方法:getitemlen,它们是所有数据集必须实现的方法。

类定子类:
重写 init 方法来初始化数据集,可能需要加载数据、解析数据等。
重写 getitem 方法来根据索引返回数据集中的一个样本,通常会包含数据的加载、解码等操作。
重写 len 方法来返回数据集中样本的数量。

import pandas as pd
from transformers import BertTokenizerFast
import torch


# 读取数据
df = pd.read_csv("./a.csv", encoding="utf-8")
# 过滤过短的content
df = df[df["content"].apply(lambda x: len(str(x))) > 10]
texts = df["content"][:100].tolist()
labels = df["punish_result"][:100].tolist()

# Hugging Face下载这个模型google-bert/bert-base-chinese
model_name = "./bert-base-chinese" 
# 加载分词器
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# 对文本进行编码
# truncation=True 文本超过max_length进行截断处理
# padding=True 文本不足max_length进行pad处理 补0
train_encodings = tokenizer(texts, truncation=True, padding=True, max_length=32)

# 封装数据为PyTorch Dataset
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # 等价上面注释写法,for循环比较好理解
        item = {}
        for key, val in self.encodings.items():
            item[key] = torch.tensor(val[idx])

        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = TextDataset(train_encodings, labels)

for dta in train_dataset:
    print(dta)
    break

# 打印数据如下:

# {'input_ids': tensor([ 101, 1585,  511,  872, 1962, 8024, 2769, 6821, 6804, 3221,  976, 6858,
#         7599, 6392, 1906, 4638,  511, 2769, 2682, 7309,  671,  678, 8024, 1493,
#         6821, 6804, 7444, 6206, 6821,  671, 1779,  102]),
# 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0]),
# 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#         1, 1, 1, 1, 1, 1, 1, 1]),
# 'labels': tensor(0)
# }

上述代码主要通过加载bert-base-chinese模型的分词器处理原始数据,之后实现一个Dataset的子类将数据封装到PyTorch框架可识别数据结构。
上述的使用可以参考:bert-base-chinese训练

数据集构造二

这种方式通常是用在测试数据集合的时候,相对训练时的数据集构造可能比较麻烦一点。

# 自定义数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self, path):
        df = pd.read_csv(path, encoding="utf-8")
        self.texts = df["content"].tolist()
        self.labels = df["punish_result"].tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        text = str(self.texts[i])
        label = int(self.labels[i])
        return text, label

path = "./data/abuse.csv"
dataset = Dataset(path)

# for i in dataset:
#     print(i) # ('喂喂王老板点开免费,我点开就扣钱',1)
#     print(type(i)) # <class 'tuple'>
#     break

# len(dataset), dataset[0]
# # (17182,('喂喂王老板点开免费,我点开就扣钱',1))

加载词典和分词器

# 加载字典和分词工具
token = BertTokenizer.from_pretrained("./bert-base-chinese")

设置辅助函数

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    """
    batch_text_or_text_pairs:
    类型: 列表或元组的列表。
    含义: 输入的文本数据,可以是单个文本列表(如果只处理单个句子)或配对的文本(如对话或翻译任务中的源语言和目标语言句子)。
    truncation:
    类型: 布尔值。
    含义: 是否对超过最大长度的文本进行截断。设置为 True 表示会截断超出长度限制的文本。
    padding:
    类型: 字符串。
    含义: 决定如何填充短于最大长度的文本。'max_length' 表示所有样本都会被填充到max_length的长度,以确保批次内的所有元素长度一致。
    max_length:
    类型: 整数。
    含义: 设定的最大序列长度。所有输入的文本将会被截断或填充到这个长度。
    return_tensors:
    类型: 字符串。
    含义: 指定返回的张量类型。'pt' 表示返回 PyTorch 张量,其他可能的选项有 'tf'(TensorFlow 张量)或 'np'(NumPy 数组)。
    return_length:
    类型: 布尔值。
    含义: 如果设置为 True,函数还会返回一个列表,其中包含每个输入文本的原始长度,这对于知道哪些部分是填充的很有用。
    """
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        truncation=True,
        padding="max_length",
        max_length=500,
        return_tensors="pt",
        return_length=True,
    )

    # input_ids: 编码之后的数字
    input_ids = data["input_ids"]

    # attention_mask是一个与输入tokens相同形状的二维数组
    # 1 表示有效的位置,即非填充的tokens。这些位置在计算注意力分数时会被考虑。
    # 0 表示填充的位置,模型在计算注意力时会忽略这些位置。
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    labels = torch.LongTensor(labels)

    # print(data['length'], data['length'].max())
    # tensor([ 56,  71,  32, 159,  34, 179,  33,  79,  49,  33,  98,  89, 212,  41,
    #      63,  58]) tensor(212)

    return input_ids, attention_mask, token_type_ids, labels

加载数据集

"""
dataset:
类型: torch.utils.data.Dataset 的实例。
含义: 指定要加载的数据集。dataset 参数接收之前定义的 TextDataset 实例,包含了预处理过的文本数据和标签。
batch_size:
类型: 整数。
含义: 每个批次(batch)中的样本数量。在这个例子中,设置为 16,意味着数据加载器每次返回的将是包含16个样本的数据批次,用于模型训练或评估。
collate_fn:
类型: 可调用对象(如函数)。
含义: 用于整理一个批次的数据。当从数据集中取出多个样本时,collate_fn 会被调用来将这些样本打包成一个批次。这对于处理变长序列(如文本)特别有用,因为需要对不同长度的序列进行填充或截断以适应批处理。如果没有提供,默认的 collate_fn 可能不适用于所有情况,特别是当数据具有复杂结构时。
shuffle:
类型: 布尔值。
含义: 是否在每个 epoch 开始时对数据集进行随机洗牌。设置为 True 表示在训练过程中数据会随机排序,有助于提高模型的泛化能力。对于验证或测试集,通常应设为 False。
drop_last:
类型: 布尔值。
含义: 如果设置为 True,在最后一个批次不足以填满整个 batch_size 时,这个批次将会被丢弃。如果设为 False,则最后一个批次可能包含少于 batch_size 的样本数量。这在某些模型训练中是有用的,尤其是当模型设计要求固定的批次大小时。
"""
loader = torch.utils.data.DataLoader(
    dataset=dataset, 
    batch_size=16, 
    collate_fn=collate_fn, 
    shuffle=True, 
    drop_last=True
)

# for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
#     print(i, input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
#     # 0 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16])
#     break
  • 8
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值