- 1.目的
- 2.可行性与难点
- 2.1 记录迭代状态
- 2.2 兼容tf.py_function
- 2.3 兼容batch()和prefetch()
- 3. 方案与对比
- 3.1 方案1 − − -- −−基于tf.data.Dataset.from_tensor_slices 构建数据集
- 3.2 方案2 − − -- −−基于Generator和tf.data.Dataset.from_generator构建数据集
- 3.3 方案对比
我们希望能够在训练过程中,保存训练集的状态,无论何时意外地中断训练,再重启恢复训练时,训练结果完全一致。当然,一个最简单的做法是为训练循环加上分支,越过已经训练的数据集部分:
for step,datas in zip(range(steps),dataset):
if step <= saved_step:
pass
else:
... # train body
这是最轻量的coding方式,不影响已经构建的数据管道,不影响训练函数体,可以保证达到试验目的,在实验结果上,和我们的希望效果是一致的、等价的,唯一的缺点是低效,当数据管道无法将数据全部同时保存在本机的内存中,需要从本地磁盘或者网络设备等其余地方实时读取时,所需要的计算资源就被严重地浪费了。我们希望找到一个高效的数据集回溯方式,避免计算资源的浪费
数据集回溯:
训练集可能是一个有限集,也可能是一个无限集,我们在构建训练集时,还会加入组合,随机打乱,筛选,采样等等操作,但不论是否是分布式训练,对于一个确定的训练载体, 在确定的epoch和全局step下,其当前从训练集中获得的数据是确定的。一个简单的例子,以 [0,1,2,3,4,5,6,7,8,9] 为基础构建一个数据集,加入了一点随机打乱,每一步都保存,那么可以有如下的代码:
import tensorflow as tf
import os
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
datas = {"test":[str(item) for item in range(10)]}
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
def py_func(inp):
return int(inp.numpy())
k = list(x.keys())
v = list(x.values())
y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
return dict(zip(k,y))
dataset = tf.data.Dataset.from_tensor_slices(datas)
dataset = dataset.shuffle(dataset.cardinality(),seed=0).map(mapfunc)
with tempfile.TemporaryDirectory() as dir_name:
checkpoint_directory = dir_name
print(dir_name)
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
checkpoint = tf.train.Checkpoint(dataset=dataset)
epoch = 1
while epoch<=5:
buf = []
for i,item in enumerate(dataset):
if (epoch == 3)and(i == 5):
epoch = float("inf")
break
buf.append(item["test"].numpy())
checkpoint.save(file_prefix=checkpoint_prefix)
print(buf)
epoch += 1
根据以上的代码,每个epoch的数据集内容如下表所示:
| epch | dataset content |
|---|---|
| 1 | [2.0, 0.0, 1.0, 4.0, 5.0, 6.0, 9.0, 7.0, 8.0, 3.0] |
| 2 | [9.0, 3.0, 6.0, 5.0, 7.0, 2.0, 1.0, 8.0, 0.0, 4.0] |
| 3 | [7.0, 3.0, 6.0, 8.0, 2.0] → color{red}{rightarrow} →[0.0, 4.0, 9.0, 1.0, 5.0] |
| 4 | [4.0, 3.0, 2.0, 8.0, 6.0, 7.0, 9.0, 5.0, 0.0, 1.0] |
| … | … |
其中 → color{red}{rightarrow} →之后的内容没有被实际运用,因为在 e p o c h = 3 rm epoch=3 epoch=3,全局 s t e p = 2 × 10 + 5 rm step=2times10+5 step=2×10+5时尚未保存就终止了。我们希望训练重启恢复时,训练集可以直接回溯到此位置,即从 e p o c h = 3 rm epoch=3 epoch=3,全局 s t e p = 2 × 10 + 5 rm step=2times10+5 step=2×10+5的位置开始训练。这就是训练集状态的精确回溯。
P.S. 如果要模型的训练结果完全一致,不仅仅是数据集状态要精确回溯,还需要确定性算法的加持,即tf.config.experimental.enable_op_determinism 本文默认大家都已经了解该内容,以下只讨论数据集状态的精确回溯问题。
2.可行性与难点众所周知,tf.data.Dataset(以下记为Dataset)由于继承自Trackable,自然是可以被tf.train.Checkpoint(以下记为Checkpoint)保存的。但是,在实际操作时,我们会遇到如下难点:
2.1 记录迭代状态Dataset 本身可以记录随机种子的状态,但不记录迭代状态,我们在Checkpoint中需要保存的是基于数据集的迭代器,而非数据集本身,否则,重启训练构建管道时,只能回溯到某个epoch,然后从下一个epoch第一个元素开始迭代。正确的做法是每次构建基于数据集的迭代器,保存和读取该迭代器,可以参考官方教程中Iterator Checkpointing的内容。
2.2 兼容tf.py_function若Dataset的构建过程基于tf.py_function,那无法被Checkpoint保存,需要将Dataset拆解成级联的dataset_B(dataset_A)两部分,将tf.py_function放入B中,只保存读取dataset_A
2.3 兼容batch()和prefetch()Dataset.batch()和Dataset.prefetch()会带来意外的结果,比如如下代码,当我们将dataset拆解成dataset_B(dataset_A)两部分,dataset_A基于Dataset.from_tensor_slices(), dataset_B基于Dataset.from_generator()和dataset_A,如果构建数据集管道时有Dataset.batch()或者Dataset.prefetch()的需要,恢复后会丢失保存前的尾端数据。比如如下代码:
import tensorflow as tf
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
datas = {"test":[str(item) for item in range(100)]}
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
def py_func(inp):
return int(inp.numpy())
k = list(x.keys())
v = list(x.values())
y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
return dict(zip(k,y))
# buf1-(无batch,无perfetch)
with tempfile.TemporaryDirectory() as dir_name:
dataset_A = tf.data.Dataset.from_tensor_slices(datas)
def wrapper(iterator):
def gen():
yield from iterator
dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
return dataset_B # 无batch prefetch
step = tf.Variable(0)
iterator = iter(dataset_A)
checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
buf1 = []
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf1.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
if step.numpy()>=13:
break
ckpt_manager.restore_or_initialize()
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf1.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
print(buf1)
#buf2-(batch,perfetch)
with tempfile.TemporaryDirectory() as dir_name:
dataset_A = tf.data.Dataset.from_tensor_slices(datas)
def wrapper(iterator):
def gen():
yield from iterator
dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
return dataset_B.batch(1).prefetch(tf.data.AUTOTUNE) # batch prefetch
step = tf.Variable(0)
iterator = iter(dataset_A)
checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
buf2 = []
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf2.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
if step.numpy()>=13:
break
ckpt_manager.restore_or_initialize()
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf2.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
print(buf2)
#buf3-(determinism)(batch,perfetch)
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism() # determinism
with tempfile.TemporaryDirectory() as dir_name:
dataset_A = tf.data.Dataset.from_tensor_slices(datas)
def wrapper(iterator):
def gen():
yield from iterator
dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
return dataset_B.batch(1).prefetch(tf.data.AUTOTUNE)
step = tf.Variable(0)
iterator = iter(dataset_A)
checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
buf3 = []
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf3.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
if step.numpy()>=13:
break
ckpt_manager.restore_or_initialize()
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
buf3.append((step.numpy(),tf.reduce_mean(item["test"]).numpy()))
ckpt_manager.save(check_interval=True,checkpoint_number=step)
print(buf3)
以上代码中做了三次对比实验,分别记录仪在buf1,buf2,buf3中,buf1记录了不使用batch和prefetch方法的数据集内容,buf2记录了正常情况下使用batch和prefetch方法处理后的数据集内容,buf3记录了在确定性算法加持下,使用batch和prefetch方法处理后的数据集内容。我们将三次结果展开记录到下表中,Checkpoint的记录与读取以及程序的中断位置都在表中进行了标注。
| buf1-(无batch,无prefetch) | buf2-(batch,prefetch) | buf3-(determinism)(batch,prefetch) |
|---|---|---|
| (1, 0.0)*记录位置1 | (1, 0.0)*记录位置1 | (1, 0.0)*记录位置1 |
| (2, 1.0) | (2, 1.0) | (2, 1.0) |
| (3, 2.0) | (3, 2.0) | (3, 2.0) |
| (4, 3.0) | (4, 3.0) | (4, 3.0) |
| (5, 4.0) | (5, 4.0) | (5, 4.0) |
| (6, 5.0) | (6, 5.0) | (6, 5.0) |
| (7, 6.0) | (7, 6.0) | (7, 6.0) |
| (8, 7.0) | (8, 7.0) | (8, 7.0) |
| (9, 8.0) | (9, 8.0) | (9, 8.0) |
| (10, 9.0) | (10, 9.0) | (10, 9.0) |
| (11, 10.0) *记录位置2 | (11, 10.0) *记录位置2 | (11, 10.0) *记录位置2 |
| (12, 11.0) | (12, 11.0) | (12, 11.0) |
| (13, 12.0) *中断位置 | (13, 12.0) *中断位置 | (13, 12.0) *中断位置 |
| (12, 11.0) *读取位置 | (12, 13.0) *读取位置 | (12, 11.0) *读取位置 |
| (13, 12.0) | (13, 14.0) | (13, 12.0) |
| (14, 13.0) | (14, 15.0) | (14, 13.0) |
| (15, 14.0) | (15, 16.0) | (15, 14.0) |
| (16, 15.0) | (16, 17.0) | (16, 15.0) |
| (17, 16.0) | (17, 18.0) | (17, 16.0) |
| (18, 17.0) | (18, 19.0) | (18, 17.0) |
| (19, 18.0) | (19, 20.0) | (19, 18.0) |
| (20, 19.0) | (20, 21.0) | (20, 19.0) |
| (21, 20.0) *记录位置3 | (21, 22.0) *记录位置3 | (21, 20.0) *记录位置3 |
| (22, 21.0) | (22, 23.0) | (22, 21.0) |
| (23, 22.0) | (23, 24.0) | (23, 22.0) |
| (24, 23.0) | (24, 25.0) | (24, 23.0) |
| (25, 24.0) | (25, 26.0) | (25, 24.0) |
buf1是完全正确且符合我们预期的数据集内容,将其作为baseline, 可以发现,中断位置前的数据集内容在buf1,buf2和buf3中是一致的,但是从读取位置开始,buf2记录的数据集内容就与其他的不同了,这就是Dataset的一个bug,如果没有确定性算法,使用batch和prefetch方法会导致再次迭代时数据丢失,这在确定性算法(buf3)下是没有的。为了避免该bug, 我们必须以确定性算法展开实验,否则每次中断后重启实验,数据集内容会丢失一部分,为实验带来额外的变数。
P.S. buf2代表的bug是很极端的情况,只有我们先构建级联的dataset时才会出现,且和tf.config.experimental.enable_op_determinism()这一实验性内容有关,因此该bug短期被直接修复的可能性不大。
3. 方案与对比综合上述内容,这里给出两种切实可行的保存与回溯训练集状态的方案。
3.1 方案1 − − -- −−基于tf.data.Dataset.from_tensor_slices 构建数据集import tensorflow as tf
import tempfile
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
datas = {"test":[str(item) for item in range(100)]}
#稍微增加点复杂性 我们基于映射结构构建string形式的数据集,以便于后续更加深入的讨论
def mapfunc(x):
def py_func(inp):
return int(inp.numpy())
k = list(x.keys())
v = list(x.values())
y = tf.py_function(py_func,inp=v,Tout=[tf.float32])
return dict(zip(k,y))
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
with tempfile.TemporaryDirectory() as dir_name:
dataset_A = tf.data.Dataset.from_tensor_slices(datas).shuffle(buffer_size=100,seed=0)
def wrapper(iterator):
def gen():
yield from iterator
dataset_B = tf.data.Dataset.from_generator(gen,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.string)})).map(map_func=mapfunc)
return dataset_B
step = tf.Variable(0)
iterator = iter(dataset_A)
checkpoint = tf.train.Checkpoint(iterator=iterator,step=step)
ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
ckpt_manager.restore_or_initialize()
for s,item in zip(range(step.numpy()+1,25+1),wrapper(iterator)):
step.assign(s)
# train_step()
ckpt_manager.save(check_interval=True,checkpoint_number=step)
其中shuffle() map() batch()和prefetch()可以按照个人需求增减,但若使用了batch()和prefetch(), 就必须以确定性算法tf.config.experimental.enable_op_determinism()为前提, 以保证数据集可以精确回溯到上一次保存的状态。如果有tf.py_function的使用需要,必须将数据集拆分成至少两部分,只能保存和回溯不需要tf.py_function的部分。shuffle()和map()的次序是不强制的,如果数据集被拆分成多部分,shuffle()应当仅存在于被保存的部分。
3.2 方案2 − − -- −−基于Generator和tf.data.Dataset.from_generator构建数据集import tempfile
import copy
import random
import logging
from typeguard import typechecked
import itertools
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
def get_random_from_seed(seed:int|None=None):
return random.Random(seed) if seed is not None else None
def random_datas(datas:list,random:random.Random|None=None):
if random is not None:
random.shuffle(datas) # since datas has been shuffled, the next shuffle will not be the same
logging.info(f"Random complete!,the first data is {datas[0]}")
return datas
class DataIter():
@typechecked
def __init__(self,datas:list,counters:dict[Literal["step","epoch"],tf.Variable],seed:int|None=None) -> None:
# self.count = 0
self._epoch = counters["epoch"]
self._step = counters["step"]
self._datas = datas
self.seed = seed
self._check()
def _check(self):
assert (self.step//self.length)==self.epoch
@property
def epoch(self):
return self._epoch.numpy()
@property
def step(self):
return self._step.numpy()
@property
def datas(self):
return copy.deepcopy(self._datas) # do not influence original data
@property
def length(self):
return len(self.datas)
def __iter__(self):
datas = self.datas
random = get_random_from_seed(self.seed)
for _ in range(self.epoch+1):
datas = random_datas(datas,random)
# self.step%len(self.datas) do not need 'minus 1', because it is the start index exactly
return itertools.islice(datas,self.step%self.length,self.length)
def __repr__(self) -> str:
return f"Static indices is epoch:{self.epoch} step:{self.step}."
datas = [{"test":str(item)} for item in range(100)]
def mapfunc(x):
k = list(x.keys())
v = list(x.values())
y = list(map(lambda inp:int(inp),v))
return dict(zip(k,y))
with tempfile.TemporaryDirectory() as dir_name:
step = tf.Variable(0)
epoch = tf.Variable(0)
checkpoint = tf.train.Checkpoint(step=step,epoch=epoch)
ckpt_manager = tf.train.CheckpointManager(checkpoint=checkpoint,directory=dir_name,max_to_keep=3,step_counter=step,checkpoint_interval=10)
counters = {"step":step,"epoch":epoch}
di = DataIter(datas,counters=counters,seed=0)
def generator(): # can be overwrited
for item in di:
yield mapfunc(item)
dataset = tf.data.Dataset.from_generator(generator,output_signature=({"test":tf.TensorSpec(shape=[],dtype=tf.float32)}))
for e in range(epoch.numpy(),5+1):
epoch.assign(e)
for s,item in zip(range(step.numpy()+1,25+1),dataset):
step.assign(s)
ckpt_manager.save(check_interval=True,checkpoint_number=step)
首先构建DataIter(),受全局的epoch和step影响,让独立的random.Random()管理shuffle()功能,而不是由Dataset()管理, 由于random本身会因为每次shuffle()而改变状态,而该状态无法由Checkpoint保存,因此DataIter()只能在每次iter()时 推理出上一次的状态,这是目前的无奈之举。
3.3 方案对比方案1
- 优点:
- shuffle可以直接由tf.data.Dataset()管理,shuffle()状态可以由checkpoint保存和读取,不需要迭代推理
- 缺点:
- shuffle() seed不完全受控,由于TensorFlow的random行为由“全局种子+操作种子”同时控制,存在潜在的意外shuffle()行为, 一旦代码中设置全局种子的行为,就会影响dataset
- 不能直接支持tf.py_function,必须将数据集手动拆分组合,且以确定性算法为前提,才能实现数据集内容的精确回溯
- 使用时,构建数据集后,必须再构建基于数据集的迭代器,若数据集被拆分成级联的两部分,每次使用,还需要基于该迭代器再次构建数据集,不方便且容易出错
方案2
- 优点:
- generator()可以为纯python代码,如果出现方案1中必须要使用tf.py_function的地方,可以直接以纯python代码代替,极大地简化了这部分内容
- shuffle()行为可以由独立的random.Random()完全管理,完全受seed和shuffle()次数控制随机行为,不依赖于TensorFlow的random行为,因此不受框架限制
- 使用时比较方便,因为generator()中是对DataIter()实例的引用,Checkpoint读取时覆盖了全局step和epoch,相当于向DataIter()实例传递了回溯信息,因此不需要重新构建DataIter()实例
- 缺点:
- 必须自行构建DataIter()或者类似功能的类,以实现shuffle()功能,且无法保存random状态至Checkpoint中,每次iter()时需要从头开始推理状态,浪费了一部分计算资源
我们认为,方案1和方案2都是可行方案,或者说,都是存在硬性缺陷的无奈之举,二者集体摆烂,所以没有明显优劣之分,只能说,目前为了实现保存与回溯训练集状态的目的,方案1和方案2都是可行之举。若随着TensorFlow2.x版本更新,方案1和方案2中存在的问题得以解决,本文将持续更新。



