栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

TensorFlow 1.10+自定义估算器通过train_and_evaluate提前停止

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TensorFlow 1.10+自定义估算器通过train_and_evaluate提前停止

我现在知道您的困惑。

stop_if_no_decrease_hook
状态文档(重点是我的):

max_steps_without_decrease:int,给定指标不降低的最大 训练步骤 数。

eval_dir:如果设置,则目录包含带有评估指标的摘要文件。默认情况下,将使用estimator.eval_dir()。

通过查看钩子(1.11版)的代码,您会发现:

def stop_if_no_metric_improvement_fn():    """Returns `True` if metric does not improve within max steps."""    eval_results = read_eval_metrics(eval_dir) #<<<<<<<<<<<<<<<<<<<<<<<    best_val = None    best_val_step = None    for step, metrics in eval_results.items(): #<<<<<<<<<<<<<<<<<<<<<<<      if step < min_steps:        continue      val = metrics[metric_name]      if best_val is None or is_lhs_better(val, best_val):        best_val = val        best_val_step = step      if step - best_val_step >= max_steps_without_improvement: #<<<<<        tf_logging.info( 'No %s in metric "%s" for %s steps, which is greater than or equal ' 'to max steps (%s) configured for early stopping.', increase_or_decrease, metric_name, step - best_val_step, max_steps_without_improvement)        return True    return False

该代码的作用是加载评估结果(随您的

evalSpec
参数生成),并提取评估结果以及
global_step
与特定评估记录相关的(或您用来计数的其他任何自定义步骤)。

这是

trainingsteps
文档部分的来源:提前停止不是根据未改进评估的次数触发的,而是根据特定步长范围内未改进评估的次数触发的(恕我直言,这有点违反直觉)。

因此,回顾一下: 是的 ,提前停止挂钩使用评估结果来决定何时削减培训, 但是
您需要输入要监视的培训步骤的数量,并记住要进行多少次评估在这个步骤中。

带有数字的示例,希望可以进一步阐明

假设您正在无限期地训练,每1k步进行一次评估。只要评估每隔1k步运行一次即可生成我们要监控的指标,评估运行方式的具体细节就无关紧要。

如果将挂钩设置为挂钩,

hook = tf.contrib.estimator.stop_if_no_decrease_hook(my_estimator,'my_metric_to_monitor', 10000)
挂钩将考虑在10k步的范围内进行评估。

由于您每1k步运行1个评估,因此如果连续10次评估没有任何改善,则可以归结为提前停止。如果这样您决定每2k步使用一次评估进行重新运行,则该挂钩将仅考虑5个连续评估的序列,而无需进行任何改进。

保持最佳模式

首先,重要的注意事项: 这与早期停止 ,通过训练保留最佳模型的副本以及一旦性能开始下降而停止训练的问题完全无关。

保持最佳模型非常容易

tf.estimator.BestExporter
,您可以在中定义一个
evalSpec
(摘录自链接):

  serving_input_receiver_fn = ... # define your serving_input_receiver_fn  exporter = tf.estimator.BestExporter(      name="best_exporter",      serving_input_receiver_fn=serving_input_receiver_fn,      exports_to_keep=5) # this will keep the 5 best checkpoints  eval_spec = [tf.estimator.evalSpec(    input_fn=eval_input_fn,    steps=100,    exporters=exporter,    start_delay_secs=0,    throttle_secs=5)]

如果您不知道如何定义

serving_input_fn

请看这里

这使您可以保留获得的总体最佳5个模型,并以

SavedModel
s形式存储(这是目前存储模型的首选方法)。



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/623572.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号