栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

QA的评价指标MAP、MRR、Accuracy@N

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

QA的评价指标MAP、MRR、Accuracy@N

问答系统性能的评价指标MAP、MRR、Accuracy@N

MAP(mean average precision)

即平均准确率,系统对所有候选答案进行评分,并按分值大小进行排序,正确答案越靠前,MAP值就越大
计算方式如下:

参考:https://www.jianshu.com/p/e1664861bc9d

比如共有三个问题,问题1有3个直接相关的答案,问题2有2个直接相关的答案,问题3有4个直接相关的答案。系统返回的答案中,问题1的3个答案的排序为1,3,5;问题2的2个答案的排序为2,3;问题3的4个答案的排序为1,2,4,6。
那么对于问题1,平均准确率为(1/1+2/3+3/5)/3=0.756
对问题2,平均准确率为(1/2+2/3)/2=0.583
对问题3,平均准确率为(1/1+2/2+3/4+4/6)/4=0.854
则MAP=(0.756+0.583+0.854)/3=0.731

MRR(Mean Reciprocal Rank)

即平均排序倒数,计算方式如下:

参考:https://www.jianshu.com/p/e1664861bc9d

比如前述的问题中,MRR=(1/1+1/2+1/1)/3=0.833

Accuracy@N

即topN准确率,计算方式如下:

参考:https://www.jianshu.com/p/e1664861bc9d
比如前述问题中,设N=1,则Accuracy@1=(1+0+1)/3=0.667

代码实现

参考:https://github.com/shuaihuaiyi/QA/blob/master/taevaluation.py

'''
qIndex2aIndex2aScore: {qIndex:{aIndex:score,...},......}
qIndex2aIndex2aLabel: {qIndex:{aIndex:label,...},......}
'''

def calculate(qIndex2aIndex2aScore,qIndex2aIndex2aLabel):
	ACC_at1List = []
	APlist = []
	RRlist = []
    for qIndex, index2scoreList in qIndex2aIndex2aScore.items():     # 对每一个问题
        index2label = qIndex2aIndex2aLabel[qIndex]     # {aindex:label,......}
        rankIndex = 0
        rightNum = 0
        curPList = []
        rankedList = sorted(index2scoreList.items(), key=lambda b: b[1], reverse=True)     # [(aindex,score),......]
        ACC_at1List.append(0)
        for info in rankedList:    # 对每一个答案
            aIndex = info[0]
            label = index2label[aIndex]
            rankIndex += 1      # 第几个答案
            if label == 1:      # 如果是正确答案
                rightNum += 1   # 正确答案数+1
                if rankIndex == 1:  # 如果是排序第一的答案
                    ACC_at1List[-1] = 1    # ACC@1
                p = float(rightNum) / rankIndex
                curPList.append(p)
        if len(curPList) > 0:
            RRlist.append(curPList[0])
            APlist.append(float(sum(curPList)) / len(curPList))
    return ACC_at1List,APlist,RRlist



def MRR(RRlist):
    return float(sum(RRlist)) / len(RRlist)


def MAP(APlist):
    return float(sum(APlist)) / len(APlist)


def ACC_at_1(ACC_at1List):
    return float(sum(ACC_at1List)) / len(ACC_at1List)

参考:https://www.jianshu.com/p/e1664861bc9d
https://blog.csdn.net/lightty/article/details/47079017
https://github.com/shuaihuaiyi/QA/blob/master/taevaluation.py

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/275905.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号