栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Tensorflow2.x 保存与回溯训练集状态的若干方法探究

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Tensorflow2.x 保存与回溯训练集状态的若干方法探究

目录
  • 1.目的
  • 2.可行性与难点
    • 2.1 记录迭代状态
    • 2.2 兼容tf.py_function
    • 2.3 兼容batch()和prefetch()
  • 3. 方案与对比
    • 3.1 方案1 − − -- −−基于tf.data.Dataset.from_tensor_slices 构建数据集
    • 3.2 方案2 − − -- −−基于Generator和tf.data.Dataset.from_generator构建数据集
    • 3.3 方案对比

1.目的

我们希望能够在训练过程中,保存训练集的状态,无论何时意外地中断训练,再重启恢复训练时,训练结果完全一致。当然,一个最简单的做法是为训练循环加上分支,越过已经训练的数据集部分:

for step,datas in zip(range(steps),dataset):
    if step <= saved_step:
        pass 
    else:
        ... # train body

这是最轻量的coding方式,不影响已经构建的数据管道,不影响训练函数体,可以保证达到试验目的,在实验结果上,和我们的希望效果是一致的、等价的,唯一的缺点是低效,当数据管道无法将数据全部同时保存在本机的内存中,需要从本地磁盘或者网络设备等其余地方实时读取时,所需要的计算资源就被严重地浪费了。我们希望找到一个高效的数据集回溯方式,避免计算资源的浪费

数据集回溯:
训练集可能是一个有限集,也可能是一个无限集,我们在构建训练集时,还会加入组合,随机打乱,筛选,采样等等操作,但不论是否是分布式训练,对于一个确定的训练载体, 在确定的epoch和全局step下,其当前从训练集中获得的数据是确定的。一个简单的例子,以 [0,1,2,3,4,5,6,7,8,9] 为基础构建一个数据集,加入了一点随机打乱,每一步都保存,那么可以有如下的代码:

import tensorflow as tf 
import os 
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
datas = {"test":[str(item) for item in  range(10)]} 
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
    def py_func(inp):
        return int(inp.numpy())
    k = list(x.keys())
    v = list(x.values())
    y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
    return dict(zip(k,y))
dataset = tf.data.Dataset.from_tensor_slices(datas)
dataset = dataset.shuffle(dataset.cardinality(),seed=0).map(mapfunc)
with tempfile.TemporaryDirectory() as dir_name:
    checkpoint_directory = dir_name
    print(dir_name)
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    checkpoint = tf.train.Checkpoint(dataset=dataset)
    epoch = 1
    while epoch<=5:
        buf = []
        for i,item in enumerate(dataset):
            if (epoch == 3)and(i == 5):
                epoch = float("inf")
                break
            buf.append(item["test"].numpy())
            checkpoint.save(file_prefix=checkpoint_prefix)
        print(buf)
        epoch += 1

根据以上的代码,每个epoch的数据集内容如下表所示:

epchdataset content
1[2.0, 0.0, 1.0, 4.0, 5.0, 6.0, 9.0, 7.0, 8.0, 3.0]
2[9.0, 3.0, 6.0, 5.0, 7.0, 2.0, 1.0, 8.0, 0.0, 4.0]
3[7.0, 3.0, 6.0, 8.0, 2.0] → color{red}{rightarrow} →[0.0, 4.0, 9.0, 1.0, 5.0]
4[4.0, 3.0, 2.0, 8.0, 6.0, 7.0, 9.0, 5.0, 0.0, 1.0]

其中 → color{red}{rightarrow} →之后的内容没有被实际运用,因为在 e p o c h = 3 rm epoch=3 epoch=3,全局 s t e p = 2 × 10 + 5 rm step=2times10+5 step=2×10+5时尚未保存就终止了。我们希望训练重启恢复时,训练集可以直接回溯到此位置,即从 e p o c h = 3 rm epoch=3 epoch=3,全局 s t e p = 2 × 10 + 5 rm step=2times10+5 step=2×10+5的位置开始训练。这就是训练集状态的精确回溯。

P.S. 如果要模型的训练结果完全一致,不仅仅是数据集状态要精确回溯,还需要确定性算法的加持,即tf.config.experimental.enable_op_determinism 本文默认大家都已经了解该内容,以下只讨论数据集状态的精确回溯问题。

2.可行性与难点

众所周知,tf.data.Dataset(以下记为Dataset)由于继承自Trackable,自然是可以被tf.train.Checkpoint(以下记为Checkpoint)保存的。但是,在实际操作时,我们会遇到如下难点:

2.1 记录迭代状态

Dataset 本身可以记录随机种子的状态,但不记录迭代状态,我们在Checkpoint中需要保存的是基于数据集的迭代器,而非数据集本身,否则,重启训练构建管道时,只能回溯到某个epoch,然后从下一个epoch第一个元素开始迭代。正确的做法是每次构建基于数据集的迭代器,保存和读取该迭代器,可以参考官方教程中Iterator Checkpointing的内容。

2.2 兼容tf.py_function

若Dataset的构建过程基于tf.py_function,那无法被Checkpoint保存,需要将Dataset拆解成级联的dataset_B(dataset_A)两部分,将tf.py_function放入B中,只保存读取dataset_A

2.3 兼容batch()和prefetch()

Dataset.batch()和Dataset.prefetch()会带来意外的结果,比如如下代码,当我们将dataset拆解成dataset_B(dataset_A)两部分,dataset_A基于Dataset.from_tensor_slices(), dataset_B基于Dataset.from_generator()和dataset_A,如果构建数据集管道时有Dataset.batch()或者Dataset.prefetch()的需要,恢复后会丢失保存前的尾端数据。比如如下代码:

import tensorflow as tf 
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

datas = {"test":[str(item) for item in  range(100)]} 
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
    def py_func(inp):
        return int(inp.numpy())
    k = list(x.keys())
    v = list(x.values())
    y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
    return dict(zip(k,y))

# buf1-(无batch,无perfetch)
with tempfile.TemporaryDirectory() as dir_name:
    dataset_A = tf.data.Dataset.from_tensor_slices(datas)
    def wrapper(iterator):
        def gen():
            yield from iterator
        dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
        return dataset_B # 无batch prefetch
    step = tf.Variable(0)
    iterator = iter(dataset_A) 
    checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
    buf1 = []
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf1.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
        if step.numpy()>=13:
            break
    ckpt_manager.restore_or_initialize()
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf1.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
    print(buf1)

#buf2-(batch,perfetch)
with tempfile.TemporaryDirectory() as dir_name:
    dataset_A = tf.data.Dataset.from_tensor_slices(datas)
    def wrapper(iterator):
        def gen():
            yield from iterator
        dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
        return dataset_B.batch(1).prefetch(tf.data.AUTOTUNE) # batch prefetch
    step = tf.Variable(0)
    iterator = iter(dataset_A) 
    checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
    buf2 = []
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf2.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
        if step.numpy()>=13:
            break
    ckpt_manager.restore_or_initialize()
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf2.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
    print(buf2)

#buf3-(determinism)(batch,perfetch)
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism() # determinism
with tempfile.TemporaryDirectory() as dir_name:
    dataset_A = tf.data.Dataset.from_tensor_slices(datas)
    def wrapper(iterator):
        def gen():
            yield from iterator
        dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
        return dataset_B.batch(1).prefetch(tf.data.AUTOTUNE)
    step = tf.Variable(0)
    iterator = iter(dataset_A) 
    checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
    buf3 = []
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf3.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
        if step.numpy()>=13:
            break
    ckpt_manager.restore_or_initialize()
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        buf3.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
        ckpt_manager.save(check_interval=True,checkpoint_number=step)
    print(buf3)

以上代码中做了三次对比实验,分别记录仪在buf1,buf2,buf3中,buf1记录了不使用batch和prefetch方法的数据集内容,buf2记录了正常情况下使用batch和prefetch方法处理后的数据集内容,buf3记录了在确定性算法加持下,使用batch和prefetch方法处理后的数据集内容。我们将三次结果展开记录到下表中,Checkpoint的记录与读取以及程序的中断位置都在表中进行了标注。

buf1-(无batch,无prefetch)buf2-(batch,prefetch)buf3-(determinism)(batch,prefetch)
(1, 0.0)*记录位置1(1, 0.0)*记录位置1(1, 0.0)*记录位置1
(2, 1.0)(2, 1.0)(2, 1.0)
(3, 2.0)(3, 2.0)(3, 2.0)
(4, 3.0)(4, 3.0)(4, 3.0)
(5, 4.0)(5, 4.0)(5, 4.0)
(6, 5.0)(6, 5.0)(6, 5.0)
(7, 6.0)(7, 6.0)(7, 6.0)
(8, 7.0)(8, 7.0)(8, 7.0)
(9, 8.0)(9, 8.0)(9, 8.0)
(10, 9.0)(10, 9.0)(10, 9.0)
(11, 10.0) *记录位置2(11, 10.0) *记录位置2(11, 10.0) *记录位置2
(12, 11.0)(12, 11.0)(12, 11.0)
(13, 12.0) *中断位置(13, 12.0) *中断位置(13, 12.0) *中断位置
(12, 11.0) *读取位置(12, 13.0) *读取位置(12, 11.0) *读取位置
(13, 12.0)(13, 14.0)(13, 12.0)
(14, 13.0)(14, 15.0)(14, 13.0)
(15, 14.0)(15, 16.0)(15, 14.0)
(16, 15.0)(16, 17.0)(16, 15.0)
(17, 16.0)(17, 18.0)(17, 16.0)
(18, 17.0)(18, 19.0)(18, 17.0)
(19, 18.0)(19, 20.0)(19, 18.0)
(20, 19.0)(20, 21.0)(20, 19.0)
(21, 20.0) *记录位置3(21, 22.0) *记录位置3(21, 20.0) *记录位置3
(22, 21.0)(22, 23.0)(22, 21.0)
(23, 22.0)(23, 24.0)(23, 22.0)
(24, 23.0)(24, 25.0)(24, 23.0)
(25, 24.0)(25, 26.0)(25, 24.0)

buf1是完全正确且符合我们预期的数据集内容,将其作为baseline, 可以发现,中断位置前的数据集内容在buf1,buf2和buf3中是一致的,但是从读取位置开始,buf2记录的数据集内容就与其他的不同了,这就是Dataset的一个bug,如果没有确定性算法,使用batch和prefetch方法会导致再次迭代时数据丢失,这在确定性算法(buf3)下是没有的。为了避免该bug, 我们必须以确定性算法展开实验,否则每次中断后重启实验,数据集内容会丢失一部分,为实验带来额外的变数。

P.S. buf2代表的bug是很极端的情况,只有我们先构建级联的dataset时才会出现,且和tf.config.experimental.enable_op_determinism()这一实验性内容有关,因此该bug短期被直接修复的可能性不大。

3. 方案与对比

综合上述内容,这里给出两种切实可行的保存与回溯训练集状态的方案。

3.1 方案1 − − -- −−基于tf.data.Dataset.from_tensor_slices 构建数据集
import tensorflow as tf 
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

datas = {"test":[str(item) for item in  range(100)]} 
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
    def py_func(inp):
        return int(inp.numpy())
    k = list(x.keys())
    v = list(x.values())
    y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
    return dict(zip(k,y))
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
with tempfile.TemporaryDirectory() as dir_name:
    dataset_A = tf.data.Dataset.from_tensor_slices(datas).shuffle(buffer_size=100,seed=0)
    def wrapper(iterator):
        def gen():
            yield from iterator
        dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
        return dataset_B
    step = tf.Variable(0)
    iterator = iter(dataset_A) 
    checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
    ckpt_manager.restore_or_initialize()
    for s,item  in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
        step.assign(s)
        # train_step()
        ckpt_manager.save(check_interval=True,checkpoint_number=step)

其中shuffle() map() batch()和prefetch()可以按照个人需求增减,但若使用了batch()和prefetch(), 就必须以确定性算法tf.config.experimental.enable_op_determinism()为前提, 以保证数据集可以精确回溯到上一次保存的状态。如果有tf.py_function的使用需要,必须将数据集拆分成至少两部分,只能保存和回溯不需要tf.py_function的部分。shuffle()和map()的次序是不强制的,如果数据集被拆分成多部分,shuffle()应当仅存在于被保存的部分。

3.2 方案2 − − -- −−基于Generator和tf.data.Dataset.from_generator构建数据集
import tempfile
import copy
import random 
import logging
from typeguard import typechecked
import itertools
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

def get_random_from_seed(seed:int|None=None):
    return random.Random(seed) if seed is not None else None
def random_datas(datas:list,random:random.Random|None=None):
    if random is not None:
        random.shuffle(datas) # since datas has been shuffled, the next shuffle will not be the same
        logging.info(f"Random complete!,the first data is {datas[0]}")
    return datas
class DataIter():
    @typechecked
    def __init__(self,datas:list,counters:dict[Literal["step","epoch"],tf.Variable],seed:int|None=None) -> None:
        # self.count = 0
        self._epoch = counters["epoch"] 
        self._step = counters["step"] 
        self._datas = datas
        
        self.seed = seed
        self._check()
    def _check(self):
        assert (self.step//self.length)==self.epoch
    @property
    def epoch(self):
        return self._epoch.numpy()
    @property
    def step(self):
        return self._step.numpy()
    @property
    def datas(self):
        return copy.deepcopy(self._datas) # do not influence original data
    @property
    def length(self):
        return len(self.datas)
    def __iter__(self):
        datas = self.datas
        random = get_random_from_seed(self.seed)
        for _ in range(self.epoch+1):
            datas = random_datas(datas,random)
        # self.step%len(self.datas) do not need 'minus 1', because it is the start index exactly
        return itertools.islice(datas,self.step%self.length,self.length) 
    def __repr__(self) -> str:
        return f"Static indices is epoch:{self.epoch} step:{self.step}."
    
datas = [{"test":str(item)} for item in  range(100)]
def mapfunc(x):
    k = list(x.keys())
    v = list(x.values())
    y = list(map(lambda inp:int(inp),v))
    return dict(zip(k,y))
with tempfile.TemporaryDirectory() as dir_name:
    step = tf.Variable(0)
    epoch = tf.Variable(0)
    checkpoint = tf.train.Checkpoint(step=step,epoch=epoch)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
    counters = {"step":step,"epoch":epoch}
    di  = DataIter(datas,counters=counters,seed=0)
    def generator(): # can be overwrited 
        for item in di:
            yield mapfunc(item)
    dataset = tf.data.Dataset.from_generator(generator,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.float32)}))
    for e in range(epoch.numpy(),5+1):
        epoch.assign(e)
        for s,item  in zip(range(step.numpy()+1,25+1),dataset):
            step.assign(s)
            ckpt_manager.save(check_interval=True,checkpoint_number=step)

首先构建DataIter(),受全局的epoch和step影响,让独立的random.Random()管理shuffle()功能,而不是由Dataset()管理, 由于random本身会因为每次shuffle()而改变状态,而该状态无法由Checkpoint保存,因此DataIter()只能在每次iter()时 推理出上一次的状态,这是目前的无奈之举

3.3 方案对比

方案1

  • 优点:
    1. shuffle可以直接由tf.data.Dataset()管理,shuffle()状态可以由checkpoint保存和读取,不需要迭代推理
  • 缺点:
    1. shuffle() seed不完全受控,由于TensorFlow的random行为由“全局种子+操作种子”同时控制,存在潜在的意外shuffle()行为, 一旦代码中设置全局种子的行为,就会影响dataset
    2. 不能直接支持tf.py_function,必须将数据集手动拆分组合,且以确定性算法为前提,才能实现数据集内容的精确回溯
    3. 使用时,构建数据集后,必须再构建基于数据集的迭代器,若数据集被拆分成级联的两部分,每次使用,还需要基于该迭代器再次构建数据集,不方便且容易出错

方案2

  • 优点:
    1. generator()可以为纯python代码,如果出现方案1中必须要使用tf.py_function的地方,可以直接以纯python代码代替,极大地简化了这部分内容
    2. shuffle()行为可以由独立的random.Random()完全管理,完全受seed和shuffle()次数控制随机行为,不依赖于TensorFlow的random行为,因此不受框架限制
    3. 使用时比较方便,因为generator()中是对DataIter()实例的引用,Checkpoint读取时覆盖了全局step和epoch,相当于向DataIter()实例传递了回溯信息,因此不需要重新构建DataIter()实例
  • 缺点:
    1. 必须自行构建DataIter()或者类似功能的类,以实现shuffle()功能,且无法保存random状态至Checkpoint中,每次iter()时需要从头开始推理状态,浪费了一部分计算资源

我们认为,方案1和方案2都是可行方案,或者说,都是存在硬性缺陷的无奈之举,二者集体摆烂,所以没有明显优劣之分,只能说,目前为了实现保存与回溯训练集状态的目的,方案1和方案2都是可行之举。若随着TensorFlow2.x版本更新,方案1和方案2中存在的问题得以解决,本文将持续更新。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/859151.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号