栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

TensorFlow保存以及恢复模型找到特定张量以及操作

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TensorFlow保存以及恢复模型找到特定张量以及操作

  1. 恢复模型和获取特定张量:  
  2. 从pb文件中恢复模型:  
  3.   
    with tf.gfile.GFile('./frozen.pb','rb') as f:  
        graph_def = tf.GraphDef()  
        graph_def.ParseFromString(f.read())  
        tf.import_graph_def(graph_def, name='')  
  4. 从ckpt和meta文件中恢复,ckpt文件保存的是模型的张量数据,meta保存的是图的架构,先导入图,然后导入各个张量的值。  
  5. meta_file, ckpt_file = 'model.meta','model.ckpt'  
    saver = tf.train.import_meta_graph(meta_file, input_map=input_map)  
    saver.restore(tf.get_default_session(), ckpt_file)  
  6. 找到特定张量,使用函数graph.get_tensor_by_name,得到的是个张量,可以直接作为Session.run(graph.get_tensor_by_name(' ')),如果张量需要输入,还需要有feed_dict。取graph的部分子图,可以用来feed相应的张量,然后输出想要的值。  
  7.   
    embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")  
    print(embeddings)  
  8. 输出:Tensor("embeddings:0", shape=(?, 512), dtype=float32)  
  9.   
  10. 找到特定操作Operation,使用函数graph.get_operation_by_name,得到的是操作节点,字典形式,其中包括输入和输出张量  
  11.   
    input_operation = graph.get_operation_by_name("import/Mul")  
    output_operation = graph.get_operation_by_name("import/final_result")  
    print(input_operation)  
    print(output_operation)  
  12. 输出:   
  13.   
  14. name: "import/Mul"  
  15. op: "Mul"  
  16. input: "import/Sub"  
  17. input: "import/Mul/y"  
  18. attr {  
  19.   key: "T"  
  20.   value {  
  21.     type: DT_FLOAT  
  22.   }  
  23. }  
  24.   
  25.   
  26. name: "import/final_result"  
  27. op: "Softmax"  
  28. input: "import/final_training_ops/Wx_plus_b/add"  
  29. attr {  
  30.   key: "T"  
  31.   value {  
  32.     type: DT_FLOAT  
  33.   }  
  34. }  
  35.   
  36.  可以通过字典找到相应的输出  
  37.   
    print(input_operation.name)  
    print(output_operation.name)  
    print(output_operation.inputs[0])  
    print(input_operation.outputs)  

  38. 输出:   
  39.   
  40. import/Mul  
  41. import/final_result  
  42. Tensor("import/final_training_ops/Wx_plus_b/add:0", shape=(?, 9), dtype=float32)  
  43. []  
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/299864.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号