栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

使用Python评估文字生成模型的详细步骤

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用Python评估文字生成模型的详细步骤

支持BLEU(1~4)、METEOR、ROUGE、CIDEr、SPICE、WMD六种评价指标的计算!

1. 代码地址

开源地址下载:https://github.com/ruotianluo/coco-caption

git clone https://github.com/ruotianluo/coco-caption.git
2. 下载spice需要的支持依赖 (1)安装model
bash get_stanford_models.sh

自动下载和解压。

(2) 安装java

spice的运行需要java支持,否则会报错:

FileNotFoundError: [Errno 2] No such file or directory: 'java'
  1. java的下载地址:https://www.oracle.com/java/technologies/downloads/#java8

    注意:这里要安装java1.8版本,不然会报错的!

    这里需要注册一个Oracle账号,亲测除了邮箱要填对(因为会发验证邮箱),其他乱填即可。

    因为我在linux下,所以选择:jdk-8u311_linux-x64_bin.tar.gz

  2. 下载好后,进入下载的地址,把文件copy到/opt下:

    sudo cp Downloads/jdk-8u311_linux-x64_bin.tar.gz /opt
    
  3. 给自己开权限:

    cd /opt
    sudo mkdir java
    sudo chown [user_name] java
    sudo chgrp [user_name] java
    

    注意,这里[user_name]是你linux的账户名。

  4. 解压:

    sudo tar -zxvf jdk-8u311_linux-x64_bin.tar.gz -C /opt/java
    
  5. 配置环境变量:

    sudo gedit /etc/profile
    

    如果是无屏幕界面,gedit换成vim即可。

  6. 追加如下信息:

    #set java environment
    export JAVA_HOME=/opt/java/jdk1.8.0_311
    export PATH=${JAVA_HOME}/bin:${PATH}
    

    保存退出后,更新一下:

    source /etc/profile
    

    但是这里我虽然当前terminal是可以显示java的版本的,但是新开一个terminal就显示不了java的版本了,还是路径索引没设置好,所以我在zshrc里也设置了一下:

    (注意!如果使用的是bash请用bashrc)

    sudo gedit ~/.zshrc
    

    把刚才添加的路径信息同样加在文件后面,然后保存退出,再更新一下:

    source ~/.zshrc
    

    ok,现在就可以正常找到java啦。

  7. 查看java是否安装成功:

    java -version
    
3. 下载WMD需要的库
bash get_google_word2vec_model.sh
4. DEMO

COCO-CAPTION里没有附demo,py文件,我给自己写了个,直接执行应该就可以啦:

# -*- coding=utf-8 -*-
# author: w61
# Test for several ways to compute the score of the generated words.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.wmd.wmd import WMD

class Scorer():
    def __init__(self,ref,gt):
        self.ref = ref
        self.gt = gt
        print('setting up scorers...')
        self.scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(),"METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
            (Spice(), "SPICE"),
            (WMD(),   "WMD"),
        ]
    
    def compute_scores(self):
        total_scores = {}
        for scorer, method in self.scorers:
            print('computing %s score...'%(scorer.method()))
            score, scores = scorer.compute_score(self.gt, self.ref)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    print("%s: %0.3f"%(m, sc))
                total_scores["Bleu"] = score
            else:
                print("%s: %0.3f"%(method, score))
                total_scores[method] = score
        
        print('*****DONE*****')
        for key,value in total_scores.items():
            print('{}:{}'.format(key,value))

if __name__ == '__main__':
    ref = {
        '1':['go down the stairs and stop at the bottom .']
    }
    gt = {
        '1':['Walk down the steps and stop at the bottom. ', 'Go down the stairs and wait at the bottom.','once at the top of the stairway, walk down the spiral staircase all the way to the bottom floor. once you have left the stairs you are in a foyer and that indicates you are at your destination.']
    }
    scorer = Scorer(ref,gt)
    scorer.compute_scores()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/487314.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号