栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

同时运行多个经过预训练的Tensorflow网络

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

同时运行多个经过预训练的Tensorflow网络

最简单的解决方案是创建不同的会话,为每个模型使用单独的图形:

# Build a graph containing `net1`.with tf.Graph().as_default() as net1_graph:  net1 = CreateAlexNet()  saver1 = tf.train.Saver(...)sess1 = tf.Session(graph=net1_graph)saver1.restore(sess1, 'epoch_10.ckpt')# Build a separate graph containing `net2`.with tf.Graph().as_default() as net2_graph:  net2 = CreateAlexNet()  saver2 = tf.train.Saver(...)sess2 = tf.Session(graph=net1_graph)saver2.restore(sess2, 'epoch_50.ckpt')

如果由于某种原因该方法不起作用,并且您必须使用一个

tf.Session
(例如,因为您希望将来自两个网络的结果合并到另一个TensorFlow计算中),最好的解决方案是:

  1. 像已经做的那样在名称范围中创建不同的网络,并且
  2. tf.train.Saver
    为两个网络创建单独的实例,并带有一个附加参数以重新映射变量名称。

当构建的储户,就可以通过一本字典作为

var_list
参数,在检查点映射变量的名称(即没有名称范围前缀)给
tf.Variable
你的每个模型创建的对象。

您可以以

var_list
编程方式进行构建,并且应该能够执行以下操作:

with tf.name_scope("net1"):  net1 = CreateAlexNet()with tf.name_scope("net2"):  net2 = CreateAlexNet()# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.net1_varlist = {v.name.lstrip("net1/"): v     for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}net1_saver = tf.train.Saver(var_list=net1_varlist)# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.net2_varlist = {v.name.lstrip("net2/"): v     for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}net2_saver = tf.train.Saver(var_list=net2_varlist)# ...net1_saver.restore(sess, "epoch_10.ckpt")net2_saver.restore(sess, "epoch_50.ckpt")


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/653092.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号