栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

如何将Keras .h5导出到tensorflow .pb?

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

如何将Keras .h5导出到tensorflow .pb?

Keras本身不包括将TensorFlow图导出为协议缓冲区文件的任何方法,但是您可以使用常规TensorFlow实用程序来实现。这是一篇博客文章,解释了如何使用

freeze_graph.py
TensorFlow中包含的实用程序脚本执行此操作,这是完成操作的“典型”方式。

但是,我个人觉得必须创建一个检查点,然后运行一个外部脚本来获取模型,但我更喜欢从我自己的Python代码中执行此操作,因此我使用了这样的函数:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):    """    Freezes the state of a session into a pruned computation graph.    Creates a new computation graph where variable nodes are replaced by    constants taking their current value in the session. The new graph will be    pruned so subgraphs that are not necessary to compute the requested    outputs are removed.    @param session The TensorFlow session to be frozen.    @param keep_var_names A list of variable names that should not be frozen,    or None to freeze all the variables in the graph.    @param output_names Names of the relevant graph outputs.    @param clear_devices Remove the device directives from the graph for better portability.    @return The frozen graph definition.    """    graph = session.graph    with graph.as_default():        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))        output_names = output_names or []        output_names += [v.op.name for v in tf.global_variables()]        input_graph_def = graph.as_graph_def()        if clear_devices: for node in input_graph_def.node:     node.device = ""        frozen_graph = tf.graph_util.convert_variables_to_constants( session, input_graph_def, output_names, freeze_var_names)        return frozen_graph

这是实施的启发

freeze_graph.py
。参数也类似于脚本。
session
是TensorFlow会话对象。
keep_var_names
仅在您希望不冻结某些变量时才需要(例如,对于有状态模型),通常不需要。
output_names
是包含产生所需输出的操作名称的列表。
clear_devices
只需删除任何设备指令即可使图形更具可移植性。因此,对于
model
具有一个输出的典型Keras
,您将执行以下操作:

from keras import backend as K# Create, compile and train model...frozen_graph = freeze_session(K.get_session(),        output_names=[out.op.name for out in model.outputs])

然后,您可以像往常一样使用

tf.train.write_graph
以下命令将图形写入文件:

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/647817.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号