栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

知识图到文本的生成——伍

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

知识图到文本的生成——伍

2021SC@SDUSC

我们继续分析dataset类,dataset类位于lastDataset.py文件中,是该算法的核心代码之一。dataset类中一共有20个类函数,我将会挑选核心的函数来分析。

首先是对数据集建立词表的build_ent_vocab函数。

  def build_ent_vocab(self,path,unkat=0):
    ents = ""
    with open(path,encoding='utf-8') as f:
      for l in f:
        ents +=  " "+l.split("t")[1]
    itos = sorted(list(set(ents.split(" "))))
    itos[0] == ""; itos[1] == ""
    stoi = {x:i for i,x in enumerate(itos)}
    return itos,stoi

参数中的path就是数据集所在的路径,调用的时候传入。unkat参数初始值为0,意为为转换。ents是声明的字符串变量,存储遍历读取到的字符串数据集。itos是一个列表变量,每个元素都是ents中根据“ ”切割出的分词。比如ents='A  B',那么itos则为['A','B'],初始化itos第一个值为unk,第二个值为pad,enumerate()函数将itos组合为索引序列,结果组合为stoi变量。返回数据对象itos和索引序列stoi。

接下来是mkGraphs函数。

  def mkGraphs(self,r,ent):
    ……
    return (adj,rel)

这个函数的作用是用adj和rel矩阵将三元组转换为entlist。具体操作非关键代码,此处不再赘述。

接下来是mkVocabs函数。

  def mkVocabs(self,args):
    args.path = args.datadir + args.data
    self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.OUTP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.TGT = data.Field(sequential=True, batch_first=True,init_token="", eos_token="")
    self.NERD = data.Field(sequential=True, batch_first=True,eos_token="")
    self.ENT = data.RawField()
    self.REL = data.RawField()
    self.SORDER = data.RawField()
    self.SORDER.is_target = False
    self.REL.is_target = False 
    self.ENT.is_target = False 
    self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]

该段代码就是对这些参数进行操作,Field类和RawField类在之前已经详细分析过,此处不再单独分析这两个类。它设置了处理后保存的路径,设置INP和OUTP为顺序数据、先生成batch dimension的tensor、以“”为开始标记、以“”为结束标记、返回带填充的minibatch和的元组。

    if args.eval:
      train = data.TabularDataset(path=args.datadir+args.traindata, format='tsv',fields=self.fields)
    else:
      train = data.TabularDataset(path=args.path, format='tsv',fields=self.fields)

    print('building vocab')

train变量为把data定义为以TSV格式存储的列的数据集。TabularDataset是一个类,用来定义以CSV、TSV或JSON格式存储的列的数据集。如果使用dict,键应该是JSON键或CSV/TSV列的子集,值应该是(name, field)的元组。这会允许我们从其JSON/CSV/TSV键名重命名列,还允许选择要加载的列的子集。

    self.OUTP.build_vocab(train, min_freq=args.outunk)   
    generics =['','','','','']
    self.OUTP.vocab.itos.extend(generics)
    for x in generics:
      self.OUTP.vocab.stoi[x] = self.OUTP.vocab.itos.index(x)
    self.TGT.vocab = copy(self.OUTP.vocab)
    specials = "method material otherscientificterm metric task".split(" ")
    for x in specials:
      for y in range(40):
        s = "<"+x+"_"+str(y)+">"
        self.TGT.vocab.stoi[s] = len(self.TGT.vocab.itos)+y
    self.NERD.build_vocab(train,min_freq=0)
    for x in generics:
      self.NERD.vocab.stoi[x] = self.OUTP.vocab.stoi[x]

首先对要输出的变量进行build_vocab操作,该函数为Field的类函数,之前已分析过,此处不再赘述。generics是作者(不是我,是写代码的人)在数据集中找的一个实例。接下来就是对这个数据集进行扩大、切割、存储操作,specials就是把“method material otherscientificterm metric task”这个字符串根据" "进行分割,也就是generics。

接下来看一个批处理函数fixBatch()。

  def fixBatch(self,b):
    ent,phlens = zip(*b.ent)
    ent,elens = self.adjToBatch(ent)
    ent = ent.to(self.args.device)
    adj,rel = zip(*b.rel)
    if self.args.sparse:
      b.rel = [adj,self.listTo(rel)]
    else:
      b.rel = [self.listTo(adj),self.listTo(rel)]
    if self.args.plan:
      b.sordertgt = self.listTo(self.pad_list(b.sordertgt))
    phlens = torch.cat(phlens,0).to(self.args.device)
    elens = elens.to(self.args.device)
    b.ent = (ent,phlens,elens)
    return b

参数b为传入的地址。ent,phlens = zip(*b.ent)和adj,rel = zip(*b.rel)为解压b,解压后仍为元组,对解压后的元组调用adjToBatch函数进行生成邻接矩阵的批处理操作,最后返回的是矩阵。最后b直接变为三元组并返回。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580694.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号