2021SC@SDUSC
总体的代码结构如下:
经过我们的第二次小组会议、第三次小组会议讨论后,我们确定了关键代码为eval.py、generator.py、lastDataset.py、pargs.py、train.py、vectorize.py。而在第一次讨论后,我负责分析的关键代码为train.py、lastDataset.py、pargs.py,而此篇主要分析train.py中的部分代码。
整个train.py中共有4个函数,分别为update_lr、train、evaluate、main。
首先,我们先分析main(args)函数。
def main(args):
try:
os.stat(args.save)
input("Save File Exists, OverWrite? for no")
except:
os.mkdir(args.save)
ds = dataset(args)
args = dynArgs(args,ds)
m = model(args)
print(args.device)
m = m.to(args.device)
if args.ckpt:
'''
with open(args.save+"/commandLineArgs.txt") as f:
clargs = f.read().strip().split("n")
argdif =[x for x in sys.argv[1:] if x not in clargs]
assert(len(argdif)==2);
assert([x for x in argdif if x[0]=='-']==['-ckpt'])
'''
cpt = torch.load(args.ckpt)
m.load_state_dict(cpt)
starte = int(args.ckpt.split("/")[-1].split(".")[0])+1
args.lr = float(args.ckpt.split("-")[-1])
print('ckpt restored')
else:
with open(args.save+"/commandLineArgs.txt",'w') as f:
f.write("n".join(sys.argv[1:]))
starte=0
o = torch.optim.SGD(m.parameters(),lr=args.lr, momentum=0.9)
# early stopping based on Val Loss
lastloss = 1000000
for e in range(starte,args.epochs):
print("epoch ",e,"lr",o.param_groups[0]['lr'])
train(m,o,ds,args)
vloss = evaluate(m,ds,args)
if args.lrwarm:
update_lr(o,args,e)
print("Saving model")
torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
if vloss > lastloss:
if args.lrdecay:
print("decay lr")
o.param_groups[0]['lr'] *= 0.5
lastloss = vloss
try:
os.stat(args.save)
input("Save File Exists, OverWrite? for no")
except:
os.mkdir(args.save)
根据程序的运行结果来看,运行“python train.py -save S”语句后,会将运行后的结果保存在名为“S”的文件夹中。如果说路径下有一个叫“S”的文件名,它会提示“Save File Exists, OverWrite?”,回车键即可重写进数据。若没有一个叫“S”的文件名,则会自动创建,用来保存数据。
ds = dataset(args) args = dynArgs(args,ds) m = model(args)
后面我们定义了三个变量。dataset、dynArgs、model都是定义的类。首先我们看dataset类:
class dataset: def __init__(self, args):
__init__函数类似于C++中的构造函数,self为原始图实例,args为自定义参数。
args.path = args.datadir + args.data
print("Loading Data from ",args.path)
self.args = args
self.mkVocabs(args)
print("Vocab sizes:")
这里是dataset的__init__类函数中的一部分,args.path即原始图的参数的路径为参数的数据路径+数据,在运行过程中显示“Loading Data from”+参数的路径。mkVocabs(args)是类dataset中的一个函数,用于构造文本。
下面我们来看mkVocabs(args)函数(其中一部分)。
def mkVocabs(self,args):
args.path = args.datadir + args.data
self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
self.OUTP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
self.TGT = data.Field(sequential=True, batch_first=True,init_token="", eos_token="")
self.NERD = data.Field(sequential=True, batch_first=True,eos_token="")
self.ENT = data.RawField()
self.REL = data.RawField()
self.SORDER = data.RawField()
self.SORDER.is_target = False
self.REL.is_target = False
self.ENT.is_target = False
self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]
首先它对参数的路径进行了定义。
之后的INP、OUTP等等,都是对它各种属性进行赋值操作。而Field和RawField则是定义的两个类。每个数据集由一种或多种类型的数据组成。每种类型的数据都由一个RawField对象表示。RawField对象不采用数据类型和它包含与数据类型应如何处理相关的参数。
RawField类除了__init__函数外,还内含了两个函数,作为数据的预处理函数和处理函数。下篇博客将继续讨论。



