2021SC@SDUSC
这次源代码分析的内容是在概述中提到的模型算法的第二步,即对事实集合进行Fact-Follow迭代,从而得到新的事实集合,然后再根据新生成的事实集合,进行Fact-Follow迭代,从而得到新的事实集合,循环往复迭代进行t跳(hop)后,得到最终的事实集合。
毫无疑问,这次源代码分析的主要内容便在于何谓Fact-Follow迭代。因此,这次的源代码分析共分为两部分,一部分是对抽象的Fact_Follow迭代算法的详细描述,另一部分则是对DrFact模型的源代码的具体分析。
一. Fact-Follow迭代算法要了解什么是Fact-Follow迭代算法,紧靠文字的三言两语很难说清,因此我们需要以下图的帮助,这能让我们更加直观的了解Fact-Follow迭代算法是如何进行的。
1. 共现检索得到从图中可见,Fact-Follow迭代函数是一个输入参数为问题q和前一跳所得事实集合(第一轮迭代时使用初始化的事实集合),输出本跳迭代新事实集合的函数。
据此,先对前一跳所得事实集合进行共现检索,即,得到。具体直观体现在算法图的左下方,红色的与绿色的矩阵S相乘,得到。
2. 语义检索得到下一步,如算法图左上方所示, 红色的与蓝色的矩阵D相乘,即,得到。
然后将和问题q作为函数g的输入参数,即,计算得到向量。
最后,将向量与矩阵D进行最大内积检索(MIPS),即,得到。
3. 和对应位置相乘得到最后一步,即是将和的对应位置相乘,即,得到最终的。 如算法图右侧所示,蓝色向量与绿色向量进行mixing,最终得到紫色向量,即本轮事实集合结果。
至此,一跳(一轮)Fact-Follow迭代完成,得到这一轮迭代的事实集合结果。
二. DrFact模型具体源代码分析废话不多说,本篇源代码分析的对象是add_middle_hops.py文件。用到了以下这些python库。
import json from absl import app from absl import flags from absl import logging from tqdm import tqdm import networkx as nx import os import numpy as np import itertools from scipy import sparse import tensorflow.compat.v1 as tf from language.labs.drkit import search_utils import pickle from collections import defaultdict
然后是照例通过flags对全局变量进行定义以及初始化。
FLAGS = flags.FLAGS
flags.DEFINE_string("linked_qas_file", None, "Path to dataset file.")
flags.DEFINE_string("drfact_format_gkb_file", None, "Path to gkb corpus.")
flags.DEFINE_string("sup_fact_result_without_ans", None, "Path to dataset file.")
flags.DEFINE_string("sup_fact_result_with_ans", None, "Path to dataset file.")
flags.DEFINE_string("f2f_index_file", None, "Path to dataset file.")
flags.DEFINE_string("f2f_nxgraph_file", None, "Path to dataset file.")
flags.DEFINE_string("output_file", None, "Path to dataset file.")
flags.DEFINE_string("do", None, "Path to dataset file.")
接下来分析函数preprare_fact2fact_network(),其作用是构建事实之间的关联矩阵S。
首先,使用os.path.join对输入的路径进行拼接,得到目标绝对路径f2f_checkpoint。
然后,指定tensorflow运行的CPU和GPU,加载f2f_checkpoint的meta文件中的超图,以及图上定义的结点参数,包括权重偏置项等需要训练的参数,也包括训练过程生成的中间参数,将加载的超图重新存储。
接下来,使用tensorflow的Session来运行fact2fact_data:0,fact2fact_indices:0和fact2fact_rowsplits:0,并将得到的结果分别存放于fact2fact_data,fact2fact_indices和fact2fact_rowsplits。
将上述三个矩阵结果,通过CSR采取按行压缩的办法, 将原始的矩阵用三个数组进行表示,存放于S中。使用nonzero()返回包含矩阵的非零元素索引的数组(行、列)的元组,存放于row和col中。然后将这两个可迭代的对象row和col进行zip()处理,得到一一对应的元组列表,再使用tqdm()传入遍历(根据adding edges进行降序排序),存入node_out_dict和node_in_dict两个字典中。其对应意义是node_out_dict字典中存放有出节点,而node_out_dict存放有入节点。
最后,将处理好的node_out_dict和node_in_dict两个字典使用pickle.dumps()进行序列化处理,存入相应的.indict和.outdict文件之中。
def preprare_fact2fact_network():
"""Loads the f2f data."""
f2f_checkpoint = os.path.join(FLAGS.f2f_index_file)
with tf.device("/cpu:0"):
with tf.Graph().as_default():
logging.info("Reading %s", f2f_checkpoint)
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(f2f_checkpoint+'.meta')
new_saver.restore(sess, f2f_checkpoint)
fact2fact_data = sess.run('fact2fact_data:0')
fact2fact_indices = sess.run('fact2fact_indices:0')
fact2fact_rowsplits = sess.run('fact2fact_rowsplits:0')
S = sparse.csr_matrix((fact2fact_data, fact2fact_indices, fact2fact_rowsplits))
row, col = S.nonzero()
f2f_nxgraph = nx.DiGraph()
node_in_dict = defaultdict(set)
node_out_dict = defaultdict(set)
for f_i, f_j in tqdm(list(zip(row, col)), desc="adding edges"):
node_out_dict[int(f_i)].add(int(f_j))
node_in_dict[int(f_j)].add(int(f_i))
# f2f_nxgraph.add_edge(int(f_i), int(f_j))
with open(FLAGS.f2f_nxgraph_file+".indict", "wb") as f:
pickle.dump(dict(node_in_dict), f)
with open(FLAGS.f2f_nxgraph_file+".outdict", "wb") as f:
pickle.dump(dict(node_out_dict), f)



