对于miulab.py报错提示list index out of range的修改方法

在学习意图检测时,运行SLU部分的代码总是一直报错list index out of range,找了很多原因都没有解决,代码及报错提示如下:

"""
Copy file (including metric) from MiuLab:

	https://github.com/MiuLab/SlotGated-SLU
"""
import os

# compute f1 score is modified from conlleval.pl
def __startOfChunk(prevTag, tag, prevTagType, tagType, chunkStart=False):
	if prevTag == 'B' and tag == 'B':
		chunkStart = True
	if prevTag == 'I' and tag == 'B':
		chunkStart = True
	if prevTag == 'O' and tag == 'B':
		chunkStart = True
	if prevTag == 'O' and tag == 'I':
		chunkStart = True

	if prevTag == 'E' and tag == 'E':
		chunkStart = True
	if prevTag == 'E' and tag == 'I':
		chunkStart = True
	if prevTag == 'O' and tag == 'E':
		chunkStart = True
	if prevTag == 'O' and tag == 'I':
		chunkStart = True

	if tag != 'O' and tag != '.' and prevTagType != tagType:
		chunkStart = True
	return chunkStart


def __endOfChunk(prevTag, tag, prevTagType, tagType, chunkEnd=False):
	if prevTag == 'B' and tag == 'B':
		chunkEnd = True
	if prevTag == 'B' and tag == 'O':
		chunkEnd = True
	if prevTag == 'I' and tag == 'B':
		chunkEnd = True
	if prevTag == 'I' and tag == 'O':
		chunkEnd = True

	if prevTag == 'E' and tag == 'E':
		chunkEnd = True
	if prevTag == 'E' and tag == 'I':
		chunkEnd = True
	if prevTag == 'E' and tag == 'O':
		chunkEnd = True
	if prevTag == 'I' and tag == 'O':
		chunkEnd = True

	if prevTag != 'O' and prevTag != '.' and prevTagType != tagType:
		chunkEnd = True
	return chunkEnd


def __splitTagType(tag):
	s = tag.split('-')
	if len(s) > 2 or len(s) == 0:
		raise ValueError('tag format wrong. it must be B-xxx.xxx')
	if len(s) == 1:
		tag = s[0]
		tagType = ""
	else:
		tag = s[0]
		tagType = s[1]
	return tag, tagType


def computeF1Score(ss, correct_slots, pred_slots, args):
	correctChunk = {}
	correctChunkCnt = 0.0
	foundCorrect = {}
	foundCorrectCnt = 0.0
	foundPred = {}
	foundPredCnt = 0.0
	correctTags = 0.0
	tokenCount = 0.0
	if ss is None:
		ss = [["UNK" for s in xx] for xx in correct_slots]
		ffile = "eval.txt"
	else:
		ffile = "eval_all.txt"
	with open(os.path.join(args.save_dir, ffile), "w", encoding="utf8") as writer:
		for correct_slot, pred_slot, tokens in zip(correct_slots, pred_slots, ss):
			inCorrect = False
			lastCorrectTag = 'O'
			lastCorrectType = ''
			lastPredTag = 'O'
			lastPredType = ''
			for c, p, token in zip(correct_slot, pred_slot, tokens):
				writer.writelines("{}\t{}\t{}\t{}\t{}\n".format(token, "n", "O", c, p))
				correctTag, correctType = __splitTagType(c)
				predTag, predType = __splitTagType(p)

				if inCorrect == True:
					if __endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
						__endOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
						(lastCorrectType == lastPredType):
						inCorrect = False
						correctChunkCnt += 1.0
						if lastCorrectType in correctChunk:
							correctChunk[lastCorrectType] += 1.0
						else:
							correctChunk[lastCorrectType] = 1.0
					elif __endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) != \
						__endOfChunk(lastPredTag, predTag, lastPredType, predType) or \
						(correctType != predType):
						inCorrect = False

				if __startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
					__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
					(correctType == predType):
					inCorrect = True

				if __startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True:
					foundCorrectCnt += 1
					if correctType in foundCorrect:
						foundCorrect[correctType] += 1.0
					else:
						foundCorrect[correctType] = 1.0

				if __startOfChunk(lastPredTag, predTag, lastPredType, predType) == True:
					foundPredCnt += 1.0
					if predType in foundPred:
						foundPred[predType] += 1.0
					else:
						foundPred[predType] = 1.0

				if correctTag == predTag and correctType == predType:
					correctTags += 1.0

				tokenCount += 1.0

				lastCorrectTag = correctTag
				lastCorrectType = correctType
				lastPredTag = predTag
				lastPredType = predType

			if inCorrect == True:
				correctChunkCnt += 1.0
				if lastCorrectType in correctChunk:
					correctChunk[lastCorrectType] += 1.0
				else:
					correctChunk[lastCorrectType] = 1.0

	if foundPredCnt > 0:
		precision = 1.0 * correctChunkCnt / foundPredCnt
	else:
		precision = 0
	if foundCorrectCnt > 0:
		recall = 1.0 * correctChunkCnt / foundCorrectCnt
	else:
		recall = 0
	if (precision + recall) > 0:
		f1 = (2.0 * precision * recall) / (precision + recall)
	else:
		f1 = 0
	out = os.popen('perl ./conlleval.pl -d \"\\t\" < {}'.format(os.path.join(args.save_dir, ffile))).readlines()
	f1 = float(out[1][out[1].find("FB1:") + 4:-1].replace(" ", "")) / 100

	return f1, precision, recall

运行后会提示list index out of range:

将报错信息中miulab.py中的out打印出来后显示是空列表[ ]

out = os.popen('perl ./conlleval.pl -d \"\\t\" < {}'.format(os.path.join(args.save_dir, ffile))).readlines()
print(out)
f1 = float(out[1][out[1].find("FB1:") + 4:-1].replace(" ", "")) / 100

再找到out这一行中的ffile所在位置,发现并没有将eval_all.txt中的数据读到out中

def computeF1Score(ss, correct_slots, pred_slots, args):
	correctChunk = {}
	correctChunkCnt = 0.0
	foundCorrect = {}
	foundCorrectCnt = 0.0
	foundPred = {}
	foundPredCnt = 0.0
	correctTags = 0.0
	tokenCount = 0.0
	if ss is None:
		ss = [["UNK" for s in xx] for xx in correct_slots]
		ffile = "eval.txt"
	else:
		ffile = "eval_all.txt"

于是我将ffile手动改为了‘eval_all.txt’,得以解决

如下

if foundPredCnt > 0:
		precision = 1.0 * correctChunkCnt / foundPredCnt
	else:
		precision = 0
	if foundCorrectCnt > 0:
		recall = 1.0 * correctChunkCnt / foundCorrectCnt
	else:
		recall = 0
	if (precision + recall) > 0:
		f1 = (2.0 * precision * recall) / (precision + recall)
	else:
		f1 = 0
	out = os.popen('perl ./conlleval.pl -d \"\\t\" < {}'.format(os.path.join(args.save_dir, ‘eval_all.txt’))).readlines()
	f1 = float(out[1][out[1].find("FB1:") + 4:-1].replace(" ", "")) / 100

	return f1, precision, recall

作者比较菜,代码能力比较弱,这个办法可以解决这个问题,但是没搞懂这样解决的原因,为什么out中没有存入信息,目录里生成了eval_all.txt文件,里面有内容,但是out里没有读下来,ffile这里有什么错误吗,希望各位大佬可以指点指点。

如果有其他有效的方法,欢迎讨论,互相学习!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值