训练通用教程-huggingface-datasets

训练通用教程-huggingface-datasets

使用datasets构建好数据集后,使用datasets.map方法进行训练。

1. 构建数据集

利用datasets.Datasets.load_from_pandas构建数据集。

import torch
from datasets import Dataset, DatasetDict
import json
import pandas as pd
from config import config

'''
build_datasets: 使用Dataset.from_pandas构建数据集。
load_datasets: 返回训练或者测试数据集。
label_to_id: pandas.map使用的转换函数。
save_to_disk: 使用DatasetDict.save_to_disk函数保存到磁盘。

'''

class Datasets:
    def __init__(self):
        self.train_datasets = self.build_datasets(config.train_data_file)
        self.valid_datasets = self.build_datasets(config.valid_data_file)
        self.datasets_dict = DatasetDict({'train': self.train_datasets, 'valid': self.valid_datasets})

    def build_datasets(self, file_path):
        # 保存token序列
        text_list = []
        # 保存label序列
        label_list = []
        datasets = pd.DataFrame()
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                line_to_json = json.loads(line)
                '''
                text_list: ['腹 痛 就 是 可 治愈 的', '头 痛 也 是']
                '''
                text_list.append(' '.join(line_to_json['text']))
                '''
                label_list = [['B', 'E', 'O', 'O', 'O'], ['B', 'E', 'O', 'O']]
                '''
                label_list.append(line_to_json['label'])
        datasets['text'] = text_list
        datasets['label'] = label_list
        # pandas.map转换:就是遍历该列的数据,并进行转换。
        datasets['label'] = datasets['label'].map(self.label_to_id)
        return Dataset.from_pandas(datasets)

    def load_datasets(self, datasets_name):
        return self.datasets_dict[datasets_name]

    def label_to_id(self, label_list):
        # label_list: ['O', 'O', 'O', 'O', 'O', 'B-sym', 'I-sym']
        return [config.tags_to_index[item] for item in label_list]

    def save_to_disk(self, save_path):
        self.datasets_dict.save_to_disk(save_path)


2. 训练通用流程
import os
import sys
import torch
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import config
from datasets_utils.load_datasets import Datasets
from models.ner import NER
import torch.optim as optim
'''
batchward(): 反向传播
train_self: datasets.map方法中传入的函数。
train: 使用datasets.map进行训练。
'''

class Train:
    def __init__(self, model_path=None):
        self.train_datasets = Datasets().load_datasets('train')
        self.model = NER() if not model_path else torch.load(model_path)
        self.optim = optim.AdamW(self.model.parameters(), lr=config.train_lr)
        self.loss = 0
        self.loss_record = 0.0
        self.loss_best = float('inf')


    def backward(self):
        self.optim.zero_grad()
        self.loss.backward()
        self.optim.step()
        self.loss_record += self.loss.item()

    def train_self(self, batch):
        '''
        batch: {'text':[], 'label':[]}
        text: ["", "", ""...], len(text)=batch_size
        label: [[],[],[]...], len(label)=batch_size
        '''
        # 模型中最后一层CRF充当了损失函数,model的__call__方法包含损失函数。
        self.loss = self.model(batch['text'], batch['label'])
        self.backward()


    def train(self, epoches, model_save_path):
        for epoch in range(1, epoches + 1):
            self.loss_record = 0.0
            self.train_datasets.map(self.train_self,
                                    batched=True,
                                    batch_size=config.train_batch_size,
                                    desc=f'epoch: {epoch}')
            print(f'epoch: {epoch} loss: {self.loss_record}')
            if self.loss_record < self.loss_best:
                torch.save(self.model, model_save_path)
                self.loss_best = self.loss_record


if __name__ == '__main__':
    with torch.device(0):
        model_save_path = os.path.join(config.train_ner_model_save_path, f'ner_best.bin')
        ner_train = Train('../models_saved/ner/ner_best.bin')
        # ner_train.train(config.num_epoch)
        ner_train.train(1, model_save_path)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值