支持BLEU(1~4)、METEOR、ROUGE、CIDEr、SPICE、WMD六种评价指标的计算!
1. 代码地址
开源地址下载:https://github.com/ruotianluo/coco-caption
git clone https://github.com/ruotianluo/coco-caption.git
2. 下载spice需要的支持依赖
(1)安装model
bash get_stanford_models.sh
自动下载和解压。
(2) 安装java
spice的运行需要java支持,否则会报错:
FileNotFoundError: [Errno 2] No such file or directory: 'java'
-
java的下载地址:https://www.oracle.com/java/technologies/downloads/#java8
注意:这里要安装java1.8版本,不然会报错的!
这里需要注册一个Oracle账号,亲测除了邮箱要填对(因为会发验证邮箱),其他乱填即可。
因为我在linux下,所以选择:
jdk-8u311_linux-x64_bin.tar.gz
-
下载好后,进入下载的地址,把文件copy到/opt下:
sudo cp Downloads/jdk-8u311_linux-x64_bin.tar.gz /opt
-
给自己开权限:
cd /opt sudo mkdir java sudo chown [user_name] java sudo chgrp [user_name] java
注意,这里[user_name]是你linux的账户名。
-
解压:
sudo tar -zxvf jdk-8u311_linux-x64_bin.tar.gz -C /opt/java
-
配置环境变量:
sudo gedit /etc/profile
如果是无屏幕界面,gedit换成vim即可。
-
追加如下信息:
#set java environment export JAVA_HOME=/opt/java/jdk1.8.0_311 export PATH=${JAVA_HOME}/bin:${PATH}
保存退出后,更新一下:
source /etc/profile
但是这里我虽然当前terminal是可以显示java的版本的,但是新开一个terminal就显示不了java的版本了,还是路径索引没设置好,所以我在zshrc里也设置了一下:
(注意!如果使用的是bash请用
bashrc
)sudo gedit ~/.zshrc
把刚才添加的路径信息同样加在文件后面,然后保存退出,再更新一下:
source ~/.zshrc
ok,现在就可以正常找到java啦。
-
查看java是否安装成功:
java -version
3. 下载WMD需要的库
bash get_google_word2vec_model.sh
4. DEMO
COCO-CAPTION里没有附demo,py文件,我给自己写了个,直接执行应该就可以啦:
# -*- coding=utf-8 -*-
# author: w61
# Test for several ways to compute the score of the generated words.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.wmd.wmd import WMD
class Scorer():
def __init__(self,ref,gt):
self.ref = ref
self.gt = gt
print('setting up scorers...')
self.scorers = [
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
(Meteor(),"METEOR"),
(Rouge(), "ROUGE_L"),
(Cider(), "CIDEr"),
(Spice(), "SPICE"),
(WMD(), "WMD"),
]
def compute_scores(self):
total_scores = {}
for scorer, method in self.scorers:
print('computing %s score...'%(scorer.method()))
score, scores = scorer.compute_score(self.gt, self.ref)
if type(method) == list:
for sc, scs, m in zip(score, scores, method):
print("%s: %0.3f"%(m, sc))
total_scores["Bleu"] = score
else:
print("%s: %0.3f"%(method, score))
total_scores[method] = score
print('*****DONE*****')
for key,value in total_scores.items():
print('{}:{}'.format(key,value))
if __name__ == '__main__':
ref = {
'1':['go down the stairs and stop at the bottom .'],
'2':['this is a cat.']
}
gt = {
'1':['Walk down the steps and stop at the bottom. ', 'Go down the stairs and wait at the bottom.','Once at the top of the stairway, walk down the spiral staircase all the way to the bottom floor. Once you have left the stairs you are in a foyer and that indicates you are at your destination.'],
'2':['It is a cat.','There is a cat over there.','cat over there.']
}
# 注意,这里如果只有一个sample,cider算出来会是0,详情请看评论区。
scorer = Scorer(ref,gt)
scorer.compute_scores()