昇思MindSpore 应用学习-基于 MindSpore 实现 BERT 对话情绪识别

基于 MindSpore 实现 BERT 对话情绪识别

模型简介

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。
BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。
在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。
因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。
BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。
对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。
下面以一个文本情感分类任务为例子来说明BERT模型的整个应用过程。

import os  # 导入os模块,用于与操作系统交互

import mindspore  # 导入MindSpore深度学习框架
from mindspore.dataset import text, GeneratorDataset, transforms  # 从mindspore.dataset导入文本处理和数据集相关的功能
from mindspore import nn, context  # 从mindspore导入神经网络模块和上下文管理器

from mindnlp._legacy.engine import Trainer, Evaluator  # 导入MindNLP的训练和评估模块
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback  # 导入用于模型保存和最佳模型回调的功能
from mindnlp._legacy.metrics import Accuracy  # 导入准确率评估指标

代码解析

  1. 导入模块
    • import os:引入操作系统相关功能,通常用于文件和路径操作。
    • import mindspore:引入MindSpore框架,提供深度学习的基础设施。
    • from mindspore.dataset import text, GeneratorDataset, transforms
      • text:处理文本数据的功能模块。
      • GeneratorDataset:可以通过生成器动态生成数据集。
      • transforms:用于数据预处理和转换的工具。
    • from mindspore import nn, context
      • nn:包含神经网络构建所需的各类模块和层。
      • context:用于设置MindSpore的运行环境和上下文配置。
    • from mindnlp._legacy.engine import Trainer, Evaluator
      • Trainer:用于模型的训练过程管理。
      • Evaluator:用于评估模型性能的工具。
    • from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
      • CheckpointCallback:用于在训练过程中保存模型的回调。
      • BestModelCallback:用于记录最佳模型的回调。
    • from mindnlp._legacy.metrics import Accuracy:引入准确率计算的指标模块。

API 解析

  • mindspore.dataset
    • 该模块提供了用于处理数据集的工具,支持文本数据的加载和转换。
  • mindspore.nn
    • 提供构建神经网络的基础组件,如层、损失函数等。
  • mindspore.context
    • 用于设置运行环境,例如选择计算设备(CPU/GPU/Ascend)。
  • mindnlp._legacy.engine
    • 提供训练和评估框架,可以简化模型的训练流程。
  • mindnlp._legacy.engine.callbacks
    • 提供回调机制,帮助用户在训练过程中实现模型保存、学习率调整等功能。
  • mindnlp._legacy.metrics
    • 提供性能评估指标,如准确率等,帮助在训练和评估阶段监测模型表现。

Building prefix dict from the default dictionary … Loading model from cache /tmp/jieba.cache Loading model cost 1.019 seconds. Prefix dict has been built successfully.

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path  # 存储数据集的路径
        self._labels, self._text_a = [], []  # 初始化标签和文本列表
        self._load()  # 调用加载数据集的方法

    def _load(self):
        # 从指定路径加载数据集
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()  # 读取数据集文件内容
        lines = dataset.split("\n")  # 按行分割数据
        for line in lines[1:-1]:  # 遍历每一行,跳过第一行和最后一行
            label, text_a = line.split("\t")  # 按制表符分割标签和文本
            self._labels.append(int(label))  # 将标签转换为整数并添加到标签列表
            self._text_a.append(text_a)  # 将文本添加到文本列表

    def __getitem__(self, index):
        # 根据索引返回标签和文本
        return self._labels[index], self._text_a[index]

    def __len__(self):
        # 返回数据集的大小
        return len(self._labels)

代码解析

  1. 类定义
    • class SentimentDataset:定义一个用于情感分析的数据集类,主要用于加载和提供数据。
  2. 初始化方法
    • def __init__(self, path)
      • self.path = path:将数据集文件路径存储到实例变量中。
      • self._labels, self._text_a = [], []:初始化两个空列表,用于存储标签和文本数据。
      • self._load():调用私有方法 _load 来加载数据。
  3. 数据加载方法
    • def _load(self)
      • with open(self.path, "r", encoding="utf-8") as f:以只读模式打开数据集文件,指定编码为UTF-8。
      • dataset = f.read():读取文件内容。
      • lines = dataset.split("\n"):将文件内容按行切割。
      • for line in lines[1:-1]:循环遍历每行数据,跳过第一行(通常是表头)和最后一行(可能为空行)。
        • label, text_a = line.split("\t"):将每行数据按制表符分割,获取标签和文本。
        • self._labels.append(int(label)):将标签转换为整数并添加到 _labels 列表。
        • self._text_a.append(text_a):将文本添加到 _text_a 列表。
  4. 索引获取方法
    • def __getitem__(self, index):根据给定的索引返回对应的标签和文本。
      • return self._labels[index], self._text_a[index]:返回标签和文本的元组。
  5. 长度获取方法
    • def __len__(self):返回数据集中样本的数量。
      • return len(self._labels):返回标签列表的长度。

API 解析

  • __init__:构造函数,用于初始化类的实例,并设置初始状态。
  • 文件读取:使用Python内置的 open 函数来读取文件内容,通过 with 语句确保文件在使用后正确关闭。
  • 列表操作:使用列表的 append 方法动态添加数据。
  • __getitem__** 和 **__len__:这两个方法是Python的数据模型方法,允许类的实例像列表一样被索引和测量长度,使得 SentimentDataset 类可以很方便地与其他数据处理库(如PyTorch或MindSpore)配合使用。

数据集

这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符(‘\t’)分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。
label–text_a
0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1–我有事等会儿就回来和你聊
2–我见到你很高兴谢谢你帮我
这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
# 使用wget命令从指定URL下载情感检测数据集并保存为emotion_detection.tar.gz

!tar xvf emotion_detection.tar.gz
# 解压下载的tar.gz文件,提取内容

代码解析

  1. 下载数据集
    • !wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
      • !wget:使用shell命令 wget 下载文件。
      • https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz:指定要下载的数据集的URL。
      • -O emotion_detection.tar.gz:将下载的文件保存为 emotion_detection.tar.gz
  2. 解压文件
    • !tar xvf emotion_detection.tar.gz
      • !tar:使用shell命令 tar 来处理归档文件。
      • xvf:这是 tar 命令的选项:
        • x:表示解压缩。
        • v:表示显示解压缩过程中的文件(verbose模式)。
        • f:表示后面跟的是文件名。
      • emotion_detection.tar.gz:要解压的文件名。

API 解析

  • wget
    • 一个用于从网络下载文件的命令行工具,支持 HTTP、HTTPS 和 FTP 协议。
  • tar
    • 用于打包和解压缩文件的命令行工具,常用于Linux和Unix系统。.tar.gz格式是经过gzip压缩的tar档案,结合了两种工具的优点。

注意事项

  • 在执行上述命令时,确保你的环境支持 ! 前缀的shell命令,这通常在Jupyter Notebook或某些支持魔法命令的环境中有效。
  • 下载和解压的操作需要网络连接,并且保存路径需要有写入权限。

数据加载和数据预处理

新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。

import numpy as np  # 导入NumPy库,用于数值计算和操作

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    # 获取当前设备目标,判断是否为Ascend
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]  # 定义数据集的列名
    
    # 创建生成器数据集
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    
    # 定义类型转换操作
    type_cast_op = transforms.TypeCast(mindspore.int32)

    def tokenize_and_pad(text):
        # 根据设备类型进行分词和填充操作
        if is_ascend:
            # 在Ascend上进行填充和截断
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            # 在非Ascend设备上只进行分词
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']  # 返回输入ID和注意力掩码

    # 数据集映射操作,应用分词和填充函数
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    
    # 为标签应用类型转换
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    
    # 批处理数据集
    if is_ascend:
        dataset = dataset.batch(batch_size)  # 在Ascend上使用常规批处理
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})  # 在非Ascend上使用填充批处理

    return dataset  # 返回处理后的数据集

代码解析

  1. 导入
    • import numpy as np:导入NumPy库,通常用于数值运算,但在此代码中未直接使用。
  2. 函数定义
    • def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True)
      • source:数据源,可以是文本文件、数据集对象等。
      • tokenizer:用于将文本转换为模型输入格式的分词器。
      • max_seq_len:设置文本的最大序列长度,超过该长度的文本将被截断。
      • batch_size:每个批次的样本数量。
      • shuffle:是否在生成数据集时打乱数据顺序。
  3. 设备判断
    • is_ascend = mindspore.get_context('device_target') == 'Ascend':判断当前执行环境是否为Ascend设备。
  4. 数据集创建
    • column_names = ["label", "text_a"]:定义数据集中包含的列名。
    • dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle):创建一个生成器数据集。
  5. 类型转换操作
    • type_cast_op = transforms.TypeCast(mindspore.int32):创建一个类型转换操作,将标签转换为32位整数。
  6. 分词和填充
    • def tokenize_and_pad(text):定义一个内部函数用于对输入文本进行分词和填充。
      • tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len):在Ascend设备上进行分词,填充到最大长度。
      • 返回分词后的 input_idsattention_mask
  7. 数据集映射
    • dataset = dataset.map(...):将文本列应用分词和填充操作,同时将标签列应用类型转换操作。
  8. 批处理
    • if is_ascend:根据设备类型选择合适的批处理方式。
      • dataset.batch(batch_size):在Ascend设备上使用常规批处理。
      • dataset.padded_batch(batch_size, pad_info=...):在其他设备上使用填充批处理,指定填充值。
  9. 返回数据集
    • return dataset:返回处理后的数据集对象。

API 解析

  • GeneratorDataset:MindSpore中的一个数据集类,允许用户通过生成器动态生成数据。
  • map:数据集的映射方法,可以对每个样本应用给定的操作。
  • TypeCast:用于将数据类型转换为指定类型的操作。
  • batch / padded_batch:用于将数据集分成批次,支持标准批处理和填充批处理,以处理不同长度的输入。

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:

from mindnlp.transformers import BertTokenizer  # 导入BertTokenizer类,用于加载和使用BERT分词器

# 从预训练的BERT模型加载分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')   

# 获取分词器的填充标记ID
tokenizer.pad_token_id  

# 创建训练数据集,使用自定义的SentimentDataset类和分词器
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)

# 创建验证数据集
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)

# 创建测试数据集,禁用打乱
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

# 获取训练数据集的列名称
dataset_train.get_col_names()  

# 从训练数据集中获取一个迭代器并打印下一个样本
print(next(dataset_train.create_tuple_iterator()))  

代码解析

  1. 导入分词器
    • from mindnlp.transformers import BertTokenizer:从MindNLP库导入BERT分词器的类。
  2. 加载预训练的BERT分词器
    • tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
      • 使用指定的模型名称加载预训练的BERT分词器,这里是中文BERT模型。
  3. 获取填充标记ID
    • tokenizer.pad_token_id:获取分词器的填充标记的ID,用于后续的填充操作。
  4. 创建数据集
    • dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
      • 创建训练数据集,使用自定义的 SentimentDataset 类,将数据集路径和分词器传入。
    • dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
      • 创建验证数据集。
    • dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
      • 创建测试数据集,并设置 shuffle=False 以保持数据顺序。
  5. 获取列名称
    • dataset_train.get_col_names():调用方法获取训练数据集中列的名称,通常是标签和文本列。
  6. 创建迭代器并打印样本
    • print(next(dataset_train.create_tuple_iterator()))
      • 创建一个迭代器,使用 next() 获取下一个样本并打印出来,通常返回的是一个元组,包含标签和分词后的文本。

API 解析

  • BertTokenizer
    • 用于处理BERT模型的文本输入,负责将文本转换为模型可以接受的ID格式,并执行必要的填充和截断。
  • from_pretrained
    • 类方法,用于加载预训练模型的分词器,支持多种语言和任务。
  • pad_token_id
    • 分词器的填充标记ID,通常用于处理不同长度的输入,确保输入形状一致。
  • SentimentDataset
    • 自定义的数据集类,用于从指定文件加载情感分析相关的数据。
  • process_dataset
    • 处理数据集的函数,执行分词、填充和批处理等操作,返回已处理的数据集。
  • create_tuple_iterator
    • 数据集的方法,用于创建一个迭代器,可以返回数据集中的样本。
  • next()
    • Python内置函数,用于获取迭代器的下一个值。

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel  # 导入BERT模型和用于序列分类的特定模型
from mindnlp._legacy.amp import auto_mixed_precision  # 导入自动混合精度的工具

# 设置BERT配置并定义训练参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)  
# 从预训练的BERT模型加载序列分类模型,设置输出标签数量为3

model = auto_mixed_precision(model, 'O1')  
# 应用自动混合精度以提高训练性能,'O1'为混合精度的优化策略

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)  
# 创建Adam优化器,设置学习率为2e-5,并将模型可训练参数作为优化目标

metric = Accuracy()  # 定义准确率作为评估指标

# 定义回调函数以保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)  
# 创建检查点回调,设置保存路径、检查点名称、保存频率和最大保留检查点数量

best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)  
# 创建最优模型回调,自动加载最佳模型

# 创建训练器
trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])  

%%time  # 记录训练时间

# 开始训练
trainer.run(tgt_columns="labels")  

代码解析

  1. 导入必要的库
    • from mindnlp.transformers import BertForSequenceClassification, BertModel:导入MindNLP中的BERT序列分类模型。
    • from mindnlp._legacy.amp import auto_mixed_precision:导入混合精度训练工具,以优化模型训练的性能和内存使用。
  2. 模型创建
    • model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
      • 从预训练的中文BERT模型加载一个用于序列分类的模型,设置标签数量为3(例如,情感分析中的三种情感)。
  3. 混合精度训练
    • model = auto_mixed_precision(model, 'O1')
      • 应用自动混合精度以节省内存和加速训练,其中 ‘O1’ 是适用于混合精度训练的优化策略。
  4. 优化器定义
    • optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
      • 使用Adam优化器,学习率设为2e-5,优化目标是模型的可训练参数。
  5. 指标定义
    • metric = Accuracy()
      • 定义准确率作为模型评估的指标。
  6. 定义回调
    • ckpoint_cb = CheckpointCallback(...)
      • 创建检查点回调,以便在训练过程中定期保存模型检查点,设置保存路径、名称和最大检查点数量。
    • best_model_cb = BestModelCallback(...)
      • 创建最佳模型回调,以自动加载保存的最佳模型。
  7. 训练器创建
    • trainer = Trainer(...)
      • 初始化训练器,传入模型、训练数据集、验证数据集、评估指标、训练周期、优化器和回调列表。
  8. 时间记录
    • %%time:在Jupyter Notebook中使用魔法命令记录代码块的执行时间。
  9. 开始训练
    • trainer.run(tgt_columns="labels")
      • 开始模型的训练过程,指定目标列为标签列。

API 解析

  • BertForSequenceClassification
    • BERT模型的变体,专门用于序列分类任务,能够处理文本分类任务的输入。
  • auto_mixed_precision
    • 用于自动应用混合精度训练,结合使用不同的数据类型以提高训练效率。
  • nn.Adam
    • Adam优化器,常用于深度学习中的参数优化。
  • Accuracy
    • 评估指标类,用于计算模型的准确率。
  • CheckpointCallback
    • 回调类,用于在训练过程中保存模型检查点。
  • BestModelCallback
    • 回调类,用于自动保存和加载最佳模型。
  • Trainer
    • MindSpore中用于管理整个训练过程的类,负责模型训练和评估的实施。
  • run
    • 启动训练过程的方法,传入训练所需的参数。

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)  
# 创建评估器,传入训练好的模型、测试数据集和评估指标

evaluator.run(tgt_columns="labels")  
# 运行评估,指定目标列为标签列

代码解析

  1. 创建评估器
    • evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
      • 使用训练好的模型、测试数据集和指定的评估指标初始化评估器。
      • network=model:传入已经训练好的模型。
      • eval_dataset=dataset_test:传入要评估的测试数据集。
      • metrics=metric:传入评估所使用的指标(如准确率)。
  2. 运行评估
    • evaluator.run(tgt_columns="labels")
      • 调用评估器的 run 方法开始评估过程,使用 tgt_columns="labels" 指定要评估的目标列为标签列。

API 解析

  • Evaluator
    • 用于评估模型性能的类,能够帮助用户在特定数据集上计算和输出模型的评估指标。
  • run
    • 方法用于执行评估过程,通常会输出评估结果和性能指标,比如准确率、F1分数等。
  • tgt_columns
    • 指定在评估过程中需要关注的标签列,通常是模型预测的目标列。

模型推理

遍历推理数据集,将结果与标签进行统一展示。

# 加载待预测数据集
dataset_infer = SentimentDataset("data/infer.tsv")  

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}  # 定义标签映射

    # 将输入文本进行分词并转换为Tensor格式
    text_tokenized = Tensor([tokenizer(text).input_ids])  
    # 使用模型进行预测
    logits = model(text_tokenized)  
    # 获取预测标签(取最大值的索引作为预测结果)
    predict_label = logits[0].asnumpy().argmax()  
    # 格式化输出信息
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"  
    if label is not None:
        info += f" , label: '{label_map[label]}'"  # 如果提供真实标签,则输出真实标签
    print(info)  # 打印信息

# 遍历待预测数据集并进行预测
for label, text in dataset_infer:
    predict(text, label)  

代码解析

  1. 加载待预测数据集
    • dataset_infer = SentimentDataset("data/infer.tsv")
      • 从指定的文件路径加载待预测的数据集,这里是 infer.tsv 文件。
  2. 定义预测函数
    • def predict(text, label=None):
      • 定义一个 predict 函数,接受文本输入和可选的真实标签。
  3. 标签映射
    • label_map = {0: "消极", 1: "中性", 2: "积极"}
      • 创建一个字典,将数值标签映射到对应的情感描述。
  4. 文本分词和转换
    • text_tokenized = Tensor([tokenizer(text).input_ids])
      • 使用提前定义的分词器对输入文本进行分词,并将生成的ID转换成Tensor格式,以便输入到模型中。
  5. 模型预测
    • logits = model(text_tokenized)
      • 将分词后的文本输入到已训练的模型中获取预测结果(logits)。
  6. 获取预测标签
    • predict_label = logits[0].asnumpy().argmax()
      • 从模型输出的logits中获取预测标签,通过取最大值的索引(argmax())来确定最可能的情感类别。
  7. 格式化输出信息
    • info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
      • 创建一个包含输入文本和预测结果的字符串。
    • if label is not None: 语句用于检查是否提供了真实标签,如果有,则将真实标签添加到输出信息中。
  8. 打印信息
    • print(info):输出格式化的信息到控制台。
  9. 遍历数据集并进行预测
    • for label, text in dataset_infer:
      • 遍历待预测的数据集 dataset_infer,对于每一对 (label, text),调用 predict 函数进行预测。

API 解析

  • SentimentDataset
    • 自定义的数据集类,用于加载情感分析任务的输入数据。
  • Tensor
    • MindSpore中的数据结构,用于存储和处理多维数据,尤其是在深度学习中作为输入或输出。
  • tokenizer
    • 分词器实例,负责将文本转换为模型可以接受的ID格式。
  • model
    • 已训练的情感分类模型,用于对输入文本生成预测结果。
  • logits
    • 模型的输出,通常是每个类别的未归一化的得分,使用 argmax() 获取预测标签。
  • argmax()
    • NumPy中的函数,用于返回数组中最大值的索引,在此用来确定最可能的情感类别。

自定义推理数据集

自己输入推理数据,展示模型的泛化能力。

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")
  • 11
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值