栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

知识图到文本的生成——贰

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

知识图到文本的生成——贰

2021SC@SDUSC

总体的代码结构如下:

经过我们的第二次小组会议、第三次小组会议讨论后,我们确定了关键代码为eval.py、generator.py、lastDataset.py、pargs.py、train.py、vectorize.py。而在第一次讨论后,我负责分析的关键代码为train.py、lastDataset.py、pargs.py,而此篇主要分析train.py中的部分代码。

整个train.py中共有4个函数,分别为update_lr、train、evaluate、main。

首先,我们先分析main(args)函数。

def main(args):
  try:
    os.stat(args.save)
    input("Save File Exists, OverWrite?  for no")
  except:
    os.mkdir(args.save)
  ds = dataset(args)
  args = dynArgs(args,ds)
  m = model(args)
  print(args.device)
  m = m.to(args.device)
  if args.ckpt:
    '''
    with open(args.save+"/commandLineArgs.txt") as f:
      clargs = f.read().strip().split("n") 
      argdif =[x for x in sys.argv[1:] if x not in clargs]
      assert(len(argdif)==2); 
      assert([x for x in argdif if x[0]=='-']==['-ckpt'])
    '''
    cpt = torch.load(args.ckpt)
    m.load_state_dict(cpt)
    starte = int(args.ckpt.split("/")[-1].split(".")[0])+1
    args.lr = float(args.ckpt.split("-")[-1])
    print('ckpt restored')
  else:
    with open(args.save+"/commandLineArgs.txt",'w') as f:
      f.write("n".join(sys.argv[1:]))
    starte=0
  o = torch.optim.SGD(m.parameters(),lr=args.lr, momentum=0.9)

  # early stopping based on Val Loss
  lastloss = 1000000
  
  for e in range(starte,args.epochs):
    print("epoch ",e,"lr",o.param_groups[0]['lr'])
    train(m,o,ds,args)
    vloss = evaluate(m,ds,args)
    if args.lrwarm:
      update_lr(o,args,e)
    print("Saving model")
    torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
    if vloss > lastloss:
      if args.lrdecay:
        print("decay lr")
        o.param_groups[0]['lr'] *= 0.5
    lastloss = vloss
  try:
    os.stat(args.save)
    input("Save File Exists, OverWrite?  for no")
  except:
    os.mkdir(args.save)

 根据程序的运行结果来看,运行“python train.py -save S”语句后,会将运行后的结果保存在名为“S”的文件夹中。如果说路径下有一个叫“S”的文件名,它会提示“Save File Exists, OverWrite?”,回车键即可重写进数据。若没有一个叫“S”的文件名,则会自动创建,用来保存数据。

  ds = dataset(args)
  args = dynArgs(args,ds)
  m = model(args)

后面我们定义了三个变量。dataset、dynArgs、model都是定义的类。首先我们看dataset类:

class dataset:

  def __init__(self, args):

__init__函数类似于C++中的构造函数,self为原始图实例,args为自定义参数。

    args.path = args.datadir + args.data
    print("Loading Data from ",args.path)
    self.args = args
    self.mkVocabs(args)
    print("Vocab sizes:")

这里是dataset的__init__类函数中的一部分,args.path即原始图的参数的路径为参数的数据路径+数据,在运行过程中显示“Loading Data from”+参数的路径。mkVocabs(args)是类dataset中的一个函数,用于构造文本。

下面我们来看mkVocabs(args)函数(其中一部分)。

  def mkVocabs(self,args):
    args.path = args.datadir + args.data
    self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.OUTP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.TGT = data.Field(sequential=True, batch_first=True,init_token="", eos_token="")
    self.NERD = data.Field(sequential=True, batch_first=True,eos_token="")
    self.ENT = data.RawField()
    self.REL = data.RawField()
    self.SORDER = data.RawField()
    self.SORDER.is_target = False
    self.REL.is_target = False 
    self.ENT.is_target = False 
    self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]

首先它对参数的路径进行了定义。

之后的INP、OUTP等等,都是对它各种属性进行赋值操作。而Field和RawField则是定义的两个类。每个数据集由一种或多种类型的数据组成。每种类型的数据都由一个RawField对象表示。RawField对象不采用数据类型和它包含与数据类型应如何处理相关的参数。

RawField类除了__init__函数外,还内含了两个函数,作为数据的预处理函数和处理函数。下篇博客将继续讨论。

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331253.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号