OpenKiwi's word-level evaluation metrics (WMT19)

Introduction

Quality Estimation (QE) aims to assess the quality of a translation system's output without using reference translations. OpenKiwi, an open-source framework built on PyTorch by the Unbabel team, implements the best-performing QE systems from the WMT 2015-18 shared tasks.

Purpose

However, the 2019 evaluation metrics differ from those of previous years: earlier editions reported F1_mult, F1_OK, and F1_BAD (over words in MT, gaps in MT, and words in SRC), whereas 2019 reports Target_F1, Target_MCC, Source_F1, and Source_MCC.

F1_mult and MCC are computed as follows:

p_ok = tp / (tp + fp)
p_bad = tn / (tn + fn)
r_ok = tp / (tp + fn)
r_bad = tn / (tn + fp)

f1_ok = 2 * p_ok * r_ok / (p_ok + r_ok)
f1_bad = 2 * p_bad * r_bad / (p_bad + r_bad)
F1_mult = f1_ok * f1_bad

MCC = (tp * tn - fp * fn) / (((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5)
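The formulas above can be sketched as a small standalone function. The counts used in the example call are hypothetical (they happen to match the confusion-matrix example later in this post), and the helper name `f1_mult_and_mcc` is mine, not OpenKiwi's:

```python
def f1_mult_and_mcc(tp, tn, fp, fn):
    """Compute F1_mult and MCC from raw binary confusion counts.

    tp/tn here are the OK/BAD hits, matching the formulas above.
    """
    p_ok = tp / (tp + fp)
    r_ok = tp / (tp + fn)
    p_bad = tn / (tn + fn)
    r_bad = tn / (tn + fp)

    f1_ok = 2 * p_ok * r_ok / (p_ok + r_ok)
    f1_bad = 2 * p_bad * r_bad / (p_bad + r_bad)
    f1_mult = f1_ok * f1_bad

    denom = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return f1_mult, mcc

# Hypothetical counts: tp=14965, tn=1157, fp=1087, fn=2015
f1m, m = f1_mult_and_mcc(14965, 1157, 1087, 2015)
```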

Code changes

OpenKiwi-master/kiwi/lib/evaluate.py, eval_word_level function (lines 288, 298)

def eval_word_level(golds, pred_files, tag_name):
    scores_table = []
    for pred_file, pred in pred_files[tag_name]:
        _check_lengths(golds[tag_name], pred)

        scores = score_word_level(
            list(flatten(golds[tag_name])), list(flatten(pred))
        )
        scores_table.append((pred_file, *scores))
    # If more than one system is provided, compute ensemble score
    if len(pred_files[tag_name]) > 1:
        ensemble_pred = _average(
            [list(flatten(pred)) for _, pred in pred_files[tag_name]]
        )
        ensemble_score = score_word_level(
            list(flatten(golds[tag_name])), ensemble_pred
        )
        scores_table.append(("*ensemble*", *ensemble_score))
    scores = np.array(
        scores_table,
        dtype=[
            ("File", "object"),
            ("F1_{}".format(const.LABELS[0]), float),
            ("F1_{}".format(const.LABELS[1]), float),
            ("F1_mult", float),
            ("MCC", float),
        ],
    )
    # Put the main metric in the first column
    scores = scores[
        [
            "File",
            "F1_mult",
            "F1_{}".format(const.LABELS[0]),
            "F1_{}".format(const.LABELS[1]),
            "MCC"
        ]
    ]
    return scores

OpenKiwi-master/kiwi/lib/evaluate.py, print_scores_table function (lines 364, 369, 375)

def print_scores_table(scores, prefix="TARGET"):
    prefix_path, scores["File"] = _extract_path_prefix(scores["File"])
    path_str = " ({})".format(prefix_path) if prefix_path else ""

    max_method_length = max(len(path_str) + 4, max(map(len, scores["File"])))
    print("-" * (max_method_length + 13 * 3))
    print("Word-level scores for {}:".format(prefix))
    print(
        "{:{width}}    {:9}    {:9}    {:9}    {:9}".format(
            "File{}".format(path_str),
            "F1_mult",
            "F1_{}".format(const.LABELS[0]),
            "F1_{}".format(const.LABELS[1]),
            "MCC",
            width=max_method_length,
        )
    )
    for score in np.sort(scores, order=["F1_mult", "File"])[::-1]:
        print(
            "{:{width}s}    {:<9.5f}    {:<9.5f}    {:<9.5f}    {:<9.5f}".format(
                *score, width=max_method_length
            )
        )

OpenKiwi-master/kiwi/metrics/functions.py, f1_scores function (lines 164, 167)

def f1_scores(hat_y, y):
    """
    Return f1_ok, f1_bad, F1_mult and MCC
    """
    p, r, f1, m, s = precision_recall_fscore_support(hat_y, y)
    f_mult = np.prod(f1)

    return (*f1, f_mult, m[0])


Example intermediate values (index 0 = OK class, index 1 = BAD class):

p  = [0.93228, 0.36475]   # p_ok,  p_bad
r  = [0.88133, 0.51559]   # r_ok,  r_bad
f1 = [0.90609, 0.42725]   # f1_ok, f1_bad
m  = [0.34336, 0.34336]   # MCC (identical for both classes)
s  = [16122.0, 16122.0]   # support (tp + tn)

OpenKiwi-master/kiwi/metrics/functions.py, precision_recall_fscore_support function (line 148)

def precision_recall_fscore_support(hat_y, y, labels=None):
    n_classes = len(labels) if labels else None
    cnfm = confusion_matrix(hat_y, y, n_classes)

    if n_classes is None:
        n_classes = cnfm.shape[0]

    scores = np.zeros((n_classes, 5))
    for class_id in range(n_classes):
        scores[class_id] = scores_for_class(class_id, cnfm)
    return scores.T.tolist()


Example values (rows = gold label, columns = predicted label):

cnfm = [[14965, 2015],    # gold = OK  (0)
        [ 1087, 1157]]    # gold = BAD (1)
      #  pred=0  pred=1

scores = [[0.93228, 0.88133, 0.90609, 0.34336, 16122.0],   # p_ok,  r_ok,  f1_ok,  mcc, support
          [0.36475, 0.51559, 0.42725, 0.34336, 16122.0]]   # p_bad, r_bad, f1_bad, mcc, support
# support = tp + tn

OpenKiwi-master/kiwi/metrics/functions.py, scores_for_class function (lines 137, 139)

def scores_for_class(class_index, cnfm):
    tp = cnfm[class_index, class_index]
    fp = cnfm[:, class_index].sum() - tp
    fn = cnfm[class_index, :].sum() - tp
    tn = cnfm.sum() - tp - fp - fn

    p = precision(tp, fp, fn)
    r = recall(tp, fp, fn)
    f1 = fscore(tp, fp, fn)
    m = mcc(tp, tn, fp, fn)
    support = tp + tn
    return p, r, f1, m, support
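The per-class count extraction above can be checked by hand. This sketch replays it for the OK class (class 0) using the confusion matrix from the example, where rows are gold labels and columns are predictions:

```python
import numpy as np

# Confusion matrix from the example: rows = gold, columns = predicted.
cnfm = np.array([[14965, 2015],
                 [1087,  1157]])

class_index = 0  # OK class
tp = cnfm[class_index, class_index]   # gold OK, predicted OK
fp = cnfm[:, class_index].sum() - tp  # predicted OK, gold BAD
fn = cnfm[class_index, :].sum() - tp  # gold OK, predicted BAD
tn = cnfm.sum() - tp - fp - fn        # gold BAD, predicted BAD
```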

OpenKiwi-master/kiwi/metrics/functions.py, mcc function (line 128)

def mcc(tp, tn, fp, fn):
    if (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) > 0:
        return (tp * tn - fp * fn) / (((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5)
    return 0
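In the binary case, MCC gives the same value whichever class is treated as positive (swapping tp with tn and fp with fn leaves both the numerator and the denominator unchanged), which is why the example output shows 0.34336 in both rows. A quick sanity check, using the counts from the example confusion matrix:

```python
def mcc(tp, tn, fp, fn):
    # Matthews correlation coefficient; 0 when the denominator vanishes.
    denom = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return (tp * tn - fp * fn) / denom if denom else 0

# Counts for class 0 (OK); swapping tp<->tn and fp<->fn gives class 1 (BAD).
assert mcc(14965, 1157, 1087, 2015) == mcc(1157, 14965, 2015, 1087)
```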

In addition, two more places need to be modified:
openkiwi/kiwi/metrics/metrics.py, summarize function (line 181)

def summarize(self):
    summary = OrderedDict()
    _, _, f1, _, _ = precision_recall_fscore_support(self.Y_HAT, self.Y)
    summary[self.metric_name] = np.prod(f1)
    for i, label in enumerate(self.labels):
        summary['F1_' + label] = f1[i]
    return self._prefix_keys(summary)

openkiwi/kiwi/metrics/metrics.py, summarize function (line 368)

def summarize(self):
    summary = {}
    mid = len(self.Y) // 2
    if mid:
        perm = np.random.permutation(len(self.Y))
        self.Y = [self.Y[idx] for idx in perm]
        self.scores = [self.scores[idx] for idx in perm]
        m = MovingF1()
        fscore, threshold = m.choose(
            m.eval(self.scores[:mid], self.Y[:mid])
        )
        predictions = [
            const.BAD_ID if score >= threshold else const.OK_ID
            for score in self.scores[mid:]
        ]
        _, _, f1, _, _ = precision_recall_fscore_support(
            predictions, self.Y[mid:]
        )
        f1_mult = np.prod(f1)
        summary = {self.metric_name: f1_mult}
    return self._prefix_keys(summary)
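The summarize method above tunes a BAD-score threshold on half of the data (via OpenKiwi's MovingF1) and evaluates F1_mult on the held-out half. The following is a hedged sketch of that tune-then-evaluate idea using a brute-force threshold sweep instead of MovingF1; all data here is synthetic and the helper names are mine:

```python
import numpy as np

OK_ID, BAD_ID = 0, 1

def f1_mult(pred, gold):
    # Product of per-class F1 scores (the F1_mult metric).
    pred, gold = np.asarray(pred), np.asarray(gold)
    f1s = []
    for cls in (OK_ID, BAD_ID):
        tp = np.sum((pred == cls) & (gold == cls))
        fp = np.sum((pred == cls) & (gold != cls))
        fn = np.sum((pred != cls) & (gold == cls))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return f1s[0] * f1s[1]

# Synthetic data: higher score means "more likely BAD".
rng = np.random.default_rng(0)
gold = rng.integers(0, 2, size=200)
scores = gold * 0.5 + rng.random(200) * 0.7

# Tune the threshold on the first half, evaluate on the second half.
mid = len(gold) // 2
best_t = max(
    np.unique(scores[:mid]),
    key=lambda t: f1_mult(np.where(scores[:mid] >= t, BAD_ID, OK_ID), gold[:mid]),
)
pred = np.where(scores[mid:] >= best_t, BAD_ID, OK_ID)
held_out = f1_mult(pred, gold[mid:])
```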

Results

(Screenshot of the run results.)
GitHub link
