栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

“基于常识知识的推理问题”源代码分析-迭代得到新事实集合

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

“基于常识知识的推理问题”源代码分析-迭代得到新事实集合

2021SC@SDUSC

这次源代码分析的内容是在概述中提到的模型算法的第二步,即对事实集合进行Fact-Follow迭代,从而得到新的事实集合,然后再根据新生成的事实集合,进行Fact-Follow迭代,从而得到新的事实集合,循环往复迭代进行t跳(hop)后,得到最终的事实集合。

毫无疑问,这次源代码分析的主要内容便在于何谓Fact-Follow迭代。因此,这次的源代码分析共分为两部分,一部分是对抽象的Fact_Follow迭代算法的详细描述,另一部分则是对DrFact模型的源代码的具体分析。

一. Fact-Follow迭代算法

要了解什么是Fact-Follow迭代算法,紧靠文字的三言两语很难说清,因此我们需要以下图的帮助,这能让我们更加直观的了解Fact-Follow迭代算法是如何进行的。

 1. 共现检索得到

从图中可见,Fact-Follow迭代函数是一个输入参数为问题q和前一跳所得事实集合(第一轮迭代时使用初始化的事实集合),输出本跳迭代新事实集合的函数。

据此,先对前一跳所得事实集合进行共现检索,即,得到。具体直观体现在算法图的左下方,红色的与绿色的矩阵S相乘,得到。

 2. 语义检索得到

下一步,如算法图左上方所示, 红色的与蓝色的矩阵D相乘,即,得到。

然后将和问题q作为函数g的输入参数,即,计算得到向量。

最后,将向量与矩阵D进行最大内积检索(MIPS),即,得到。

 3. 和对应位置相乘得到

最后一步,即是将和的对应位置相乘,即,得到最终的。 如算法图右侧所示,蓝色向量与绿色向量进行mixing,最终得到紫色向量,即本轮事实集合结果。

至此,一跳(一轮)Fact-Follow迭代完成,得到这一轮迭代的事实集合结果。

 二. DrFact模型具体源代码分析

 废话不多说,本篇源代码分析的对象是add_middle_hops.py文件。用到了以下这些python库。

 

import json

from absl import app
from absl import flags
from absl import logging
from tqdm import tqdm
import networkx as nx 
import os
import numpy as np
import itertools
from scipy import sparse
import tensorflow.compat.v1 as tf
from language.labs.drkit import search_utils
import pickle
from collections import defaultdict

 然后是照例通过flags对全局变量进行定义以及初始化。

FLAGS = flags.FLAGS

flags.DEFINE_string("linked_qas_file", None, "Path to dataset file.")
flags.DEFINE_string("drfact_format_gkb_file", None, "Path to gkb corpus.")
flags.DEFINE_string("sup_fact_result_without_ans", None, "Path to dataset file.")
flags.DEFINE_string("sup_fact_result_with_ans", None, "Path to dataset file.")
flags.DEFINE_string("f2f_index_file", None, "Path to dataset file.")
flags.DEFINE_string("f2f_nxgraph_file", None, "Path to dataset file.")
flags.DEFINE_string("output_file", None, "Path to dataset file.")
flags.DEFINE_string("do", None, "Path to dataset file.")

接下来分析函数preprare_fact2fact_network(),其作用是构建事实之间的关联矩阵S。

首先,使用os.path.join对输入的路径进行拼接,得到目标绝对路径f2f_checkpoint。

然后,指定tensorflow运行的CPU和GPU,加载f2f_checkpoint的meta文件中的超图,以及图上定义的结点参数,包括权重偏置项等需要训练的参数,也包括训练过程生成的中间参数,将加载的超图重新存储。

接下来,使用tensorflow的Session来运行fact2fact_data:0,fact2fact_indices:0和fact2fact_rowsplits:0,并将得到的结果分别存放于fact2fact_data,fact2fact_indices和fact2fact_rowsplits。

将上述三个矩阵结果,通过CSR采取按行压缩的办法, 将原始的矩阵用三个数组进行表示,存放于S中。使用nonzero()返回包含矩阵的非零元素索引的数组(行、列)的元组,存放于row和col中。然后将这两个可迭代的对象row和col进行zip()处理,得到一一对应的元组列表,再使用tqdm()传入遍历(根据adding edges进行降序排序),存入node_out_dict和node_in_dict两个字典中。其对应意义是node_out_dict字典中存放有出节点,而node_out_dict存放有入节点。

最后,将处理好的node_out_dict和node_in_dict两个字典使用pickle.dumps()进行序列化处理,存入相应的.indict和.outdict文件之中。

def preprare_fact2fact_network():
  """Loads the f2f data."""
  f2f_checkpoint = os.path.join(FLAGS.f2f_index_file)
  with tf.device("/cpu:0"):
    with tf.Graph().as_default():
      logging.info("Reading %s", f2f_checkpoint) 
      with tf.Session() as sess: 
        new_saver = tf.train.import_meta_graph(f2f_checkpoint+'.meta')
        new_saver.restore(sess, f2f_checkpoint)
        fact2fact_data = sess.run('fact2fact_data:0')
        fact2fact_indices = sess.run('fact2fact_indices:0')
        fact2fact_rowsplits = sess.run('fact2fact_rowsplits:0')
        S = sparse.csr_matrix((fact2fact_data, fact2fact_indices, fact2fact_rowsplits))
  row, col = S.nonzero()
  f2f_nxgraph = nx.DiGraph()
  node_in_dict = defaultdict(set)
  node_out_dict = defaultdict(set)
  for f_i, f_j in tqdm(list(zip(row, col)), desc="adding edges"):    
    node_out_dict[int(f_i)].add(int(f_j))
    node_in_dict[int(f_j)].add(int(f_i))
    # f2f_nxgraph.add_edge(int(f_i), int(f_j))
  
  with open(FLAGS.f2f_nxgraph_file+".indict", "wb") as f:
    pickle.dump(dict(node_in_dict), f)
  with open(FLAGS.f2f_nxgraph_file+".outdict", "wb") as f:
    pickle.dump(dict(node_out_dict), f)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/444984.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号