使用Python评估文字生成模型的详细步骤

支持BLEU(1~4)、METEOR、ROUGE、CIDEr、SPICE、WMD六种评价指标的计算!

1. 代码地址

开源地址下载:https://github.com/ruotianluo/coco-caption

git clone https://github.com/ruotianluo/coco-caption.git

2. 下载spice需要的支持依赖

(1)安装model

bash get_stanford_models.sh

自动下载和解压。

(2) 安装java

spice的运行需要java支持,否则会报错:

FileNotFoundError: [Errno 2] No such file or directory: 'java'
  1. java的下载地址:https://www.oracle.com/java/technologies/downloads/#java8

    注意:这里要安装java1.8版本,不然会报错的!

    这里需要注册一个Oracle账号,亲测除了邮箱要填对(因为会发验证邮箱),其他乱填即可。

    因为我在linux下,所以选择:jdk-8u311_linux-x64_bin.tar.gz

  2. 下载好后,进入下载的地址,把文件copy到/opt下:

    sudo cp Downloads/jdk-8u311_linux-x64_bin.tar.gz /opt
    
  3. 给自己开权限:

    cd /opt
    sudo mkdir java
    sudo chown [user_name] java
    sudo chgrp [user_name] java
    

    注意,这里[user_name]是你linux的账户名。

  4. 解压:

    sudo tar -zxvf jdk-8u311_linux-x64_bin.tar.gz -C /opt/java
    
  5. 配置环境变量:

    sudo gedit /etc/profile
    

    如果是无屏幕界面,gedit换成vim即可。

  6. 追加如下信息:

    #set java environment
    export JAVA_HOME=/opt/java/jdk1.8.0_311
    export PATH=${JAVA_HOME}/bin:${PATH}
    

    保存退出后,更新一下:

    source /etc/profile
    

    但是这里我虽然当前terminal是可以显示java的版本的,但是新开一个terminal就显示不了java的版本了,还是路径索引没设置好,所以我在zshrc里也设置了一下:

    (注意!如果使用的是bash请用bashrc)

    sudo gedit ~/.zshrc
    

    把刚才添加的路径信息同样加在文件后面,然后保存退出,再更新一下:

    source ~/.zshrc
    

    ok,现在就可以正常找到java啦。

  7. 查看java是否安装成功:

    java -version
    

3. 下载WMD需要的库

bash get_google_word2vec_model.sh

4. DEMO

COCO-CAPTION里没有附demo,py文件,我给自己写了个,直接执行应该就可以啦:

# -*- coding=utf-8 -*-
# author: w61
# Test for several ways to compute the score of the generated words.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.wmd.wmd import WMD

class Scorer():
    def __init__(self,ref,gt):
        self.ref = ref
        self.gt = gt
        print('setting up scorers...')
        self.scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(),"METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
            (Spice(), "SPICE"),
            (WMD(),   "WMD"),
        ]
    
    def compute_scores(self):
        total_scores = {}
        for scorer, method in self.scorers:
            print('computing %s score...'%(scorer.method()))
            score, scores = scorer.compute_score(self.gt, self.ref)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    print("%s: %0.3f"%(m, sc))
                total_scores["Bleu"] = score
            else:
                print("%s: %0.3f"%(method, score))
                total_scores[method] = score
        
        print('*****DONE*****')
        for key,value in total_scores.items():
            print('{}:{}'.format(key,value))

if __name__ == '__main__':
    ref = {
        '1':['go down the stairs and stop at the bottom .'],
        '2':['this is a cat.']
    }
    gt = {
        '1':['Walk down the steps and stop at the bottom. ', 'Go down the stairs and wait at the bottom.','Once at the top of the stairway, walk down the spiral staircase all the way to the bottom floor. Once you have left the stairs you are in a foyer and that indicates you are at your destination.'],
        '2':['It is a cat.','There is a cat over there.','cat over there.']
    }
    # 注意,这里如果只有一个sample,cider算出来会是0,详情请看评论区。
    scorer = Scorer(ref,gt)
    scorer.compute_scores()
  • 3
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值