小黑啃fastNLP: Solving a Classification Task with a Custom Metric

1. Defining ClassfierMetric by subclassing MetricBase

import warnings
import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import seq_len_to_mask, _get_func_signature

class ClassfierMetric(MetricBase):
    
    def __init__(self, pred=None, target=None, seq_len=None):
        """
        :param pred: mapping for `pred` in the parameter map; None means `pred` -> `pred`
        :param target: mapping for `target` in the parameter map; None means `target` -> `target`
        :param seq_len: mapping for `seq_len` in the parameter map; None means `seq_len` -> `seq_len`
        """
        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
        
        self.total = 0
        self.total_pred = []
        self.total_target = []
    
    def evaluate(self, pred, target, seq_len=None):
        """
        evaluate accumulates the metric statistics over one batch of predictions.

        :param torch.Tensor pred: prediction tensor of shape torch.Size([B,]), torch.Size([B, n_classes]),
                torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
        :param torch.Tensor target: ground-truth tensor; for the pred shapes above the matching shapes are
                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), or torch.Size([B, max_len])
        :param torch.Tensor seq_len: sequence lengths of shape torch.Size([B]); may be None for the
                sentence-level cases above. If a mask is passed in, seq_len is ignored.
        """
        # pred and target may come in several shapes
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")
        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_len)}.")
        
        if seq_len is not None and target.dim() > 1:
            max_len = target.size(1)
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = None
        
        if pred.dim() == target.dim():
            pass
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")
        target = target.to(pred)
        if masks is not None:
            self.total += torch.sum(masks).item()
        else:
            self.total += np.prod(list(pred.size()))
        
        
        self.total_pred.extend(pred.cpu().tolist())
        self.total_target.extend(target.cpu().tolist())

    def get_metric(self, reset=True):
        p = precision_score(self.total_target, self.total_pred, average='micro')
        r = recall_score(self.total_target, self.total_pred, average='micro')
        f = f1_score(self.total_target, self.total_pred, average='micro')
        print(classification_report(self.total_target, self.total_pred))
        if reset:
            self.total = 0
            self.total_pred = []
            self.total_target = []
        return {'P': p, 'R': r, 'F': f}
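
To sanity-check the metric outside of a Trainer/Tester, it can be fed a single made-up batch directly; the tensors below are purely illustrative:

# minimal sketch: one fake batch of 2 samples and 3 classes, values are made up
metric = ClassfierMetric()
fake_pred = torch.tensor([[0.1, 0.9, 0.0],
                          [0.8, 0.1, 0.1]])   # shape [B, n_classes]
fake_target = torch.tensor([1, 2])            # shape [B,]
metric.evaluate(pred=fake_pred, target=fake_target)
print(metric.get_metric())                    # micro P/R/F over the fake batch, here 0.5 each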

2. Core code

from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.io import ChnSentiCorpPipe, DataBundle
from tqdm import tqdm
from fastNLP.embeddings import StaticEmbedding, StackEmbedding, BertEmbedding, LSTMCharEmbedding
from fastNLP import Trainer, CrossEntropyLoss, BCELoss
import torch
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from fastNLP.models import BertForSequenceClassification, CNNText

def get_data(file):
    with open(file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    dataset = DataSet()

    for line in tqdm(lines):
        fields = line.split('\t')
        assert len(fields) == 3
        # second column holds the text, third column the label
        text = fields[1].strip()
        label = fields[2].strip()

        # build a fastNLP Instance: character-level tokens as `words`, raw label string as `target`
        instance = Instance(words=list(text), target=label)
        dataset.append(instance)

    return dataset
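
get_data assumes each line of the input file has three tab-separated columns; judging from the indices used above, the layout is roughly id, text, label (the exact column meanings and the sample content below are assumptions):

# small sketch with a made-up one-line file in the assumed "<id>\t<text>\t<label>" layout
sample_line = "1001\t这家酒店的服务还不错\t正面\n"
with open('demo.txt', 'w', encoding='utf-8') as f:
    f.write(sample_line)
demo_set = get_data('demo.txt')
print(demo_set[0])   # Instance with character-level `words` and the raw label string as `target`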

def merge_data(train_set, test_set):
    target_vocab = Vocabulary(padding=None, unknown=None)
    target_vocab.from_dataset(train_set, test_set, field_name='target')
    target_vocab.index_dataset(train_set, test_set, field_name='target')

    char_vocab = Vocabulary()
    char_vocab.from_dataset(train_set, test_set, field_name='words')
    char_vocab.index_dataset(train_set, test_set, field_name='words')

    return train_set, test_set, target_vocab, char_vocab
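
A variation worth noting: if you want only the training set to create vocabulary entries (so characters that appear only in the test set are handled as unseen words), fastNLP's Vocabulary.from_dataset accepts a no_create_entry_dataset argument. The sketch below is an alternative to, not part of, the code above:

# alternative sketch (assumes fastNLP's no_create_entry_dataset argument):
# only train_set adds entries; test_set is still indexed but contributes no new entries
char_vocab = Vocabulary()
char_vocab.from_dataset(train_set, field_name='words', no_create_entry_dataset=[test_set])
char_vocab.index_dataset(train_set, test_set, field_name='words')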


def evaluate(test_set, model):
    probs = []
    labels = []
    for i in tqdm(range(len(test_set))):
        word_ids = test_set.words[i]
        word_ids = torch.LongTensor(word_ids)
        pred = model.predict(word_ids.view(1, -1))
        prob = pred['pred'].numpy()[0]
        target = test_set.target[i]
        probs.append(prob)
        labels.append(target)
    p = precision_score(labels, probs, average='micro')
    r = recall_score(labels, probs, average='micro')
    f1 = f1_score(labels, probs, average='micro')
    print(classification_report(labels, probs))
    print('P:', p)
    print('R:', r)
    print('F1:', f1)


train_set = get_data('./classify_data3/train.txt')
test_set = get_data('./classify_data3/test.txt')
train_set, test_set, target_vocab, char_vocab = merge_data(train_set, test_set)


fastnlp_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d', min_freq=2)
model_CNN = CNNText(fastnlp_embed, num_classes=27, dropout=0.1)
model_CNN.load_state_dict(torch.load('model_CNN.pth')['net'])
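# assumption: 'model_CNN.pth' is a checkpoint saved beforehand with its weights stored under
# the key 'net', e.g. torch.save({'net': model_CNN.state_dict()}, 'model_CNN.pth')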
train_set.set_target('target')
train_set.set_input('words')

test_set.set_target('target')
test_set.set_input('words')


from fastNLP import Tester
tester = Tester(data=test_set, model=model_CNN, metrics=ClassfierMetric(), device='cpu')
tester.test()
Output:

607 out of 6280 words have frequency less than 2.
Found 5579 out of 5673 words in the pre-training embedding.
              precision    recall  f1-score   support

       0       0.53      0.78      0.63      4440
       1       0.50      0.48      0.49      3511
       2       0.62      0.57      0.59      3016
       3       0.66      0.57      0.61      2722
       4       0.56      0.56      0.56      2683
       5       0.65      0.63      0.64      1303
       6       0.54      0.62      0.58      1268
       7       0.64      0.76      0.70      1199
       8       0.87      0.60      0.71      1086
       9       0.88      0.67      0.76      1018
      10       0.58      0.43      0.49       896
      11       0.58      0.43      0.50       850
      12       0.70      0.61      0.65       496
      13       0.52      0.57      0.54       460
      14       0.48      0.08      0.14       365
      15       0.79      0.75      0.77       361
      16       0.55      0.55      0.55       352
      17       0.61      0.53      0.57       322
      18       0.69      0.44      0.54       295
      19       0.81      0.75      0.78       262
      20       0.64      0.35      0.45       133
      21       0.85      0.29      0.43        76
      22       0.85      0.62      0.72        37
      23       0.00      0.00      0.00        32
      24       0.20      0.08      0.11        25

    accuracy                           0.59     27208
   macro avg       0.61      0.51      0.54     27208
weighted avg       0.60      0.59      0.59     27208

Evaluate data in 5.79 seconds!
[tester]
ClassfierMetric: P=0.5936489267862394, R=0.5936489267862394, F=0.5936489267862394
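
The same metric can drive dev-set evaluation during training as well. Below is a rough sketch using the imports from above; the optimizer choice, batch size, epoch count and learning rate are placeholder assumptions rather than values from this post:

# rough training sketch: the custom metric is evaluated on dev_data after each epoch;
# all hyper-parameters here are placeholder assumptions
from torch import optim

optimizer = optim.Adam(model_CNN.parameters(), lr=1e-3)
trainer = Trainer(train_data=train_set, model=model_CNN,
                  loss=CrossEntropyLoss(), optimizer=optimizer,
                  metrics=ClassfierMetric(), dev_data=test_set,
                  batch_size=32, n_epochs=5, device='cpu')
trainer.train()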
