栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

TF2.0中的saved_model.prune()

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TF2.0中的saved_model.prune()

看来您在第1版中修剪模型的方式很好;根据您的错误消息,无法保存生成的修剪模型,因为它不是“可跟踪的”,这是使用保存模型的必要条件

tf.saved_model.save
。生成可跟踪对象的一种方法是从
tf.Module
类继承,如使用SavedModel格式和具体函数的指南中所述。下面是一个示例,尝试保存一个
tf.function
对象(由于该对象不可跟踪而失败),从继承
tf.module
并保存生成的对象:

(使用Python版本3.7.6,TensorFlow版本2.1.0和NumPy版本1.18.1)

import tensorflow as tf, numpy as np# Define a random TensorFlow function and generate a reference outputconv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)@tf.functiondef conv_model(x):    return tf.nn.conv2d(x, conv_filter, 1, "SAME")input_tensor = tf.ones([1, 2, 3, 4])output_tensor = conv_model(input_tensor)print("Original model outputs:", output_tensor, sep="n")# Try saving the model: it won't work because a tf.function is not trackableexport_dir = "./tmp/"try: tf.saved_model.save(conv_model, export_dir)except ValueError: print(    "Can't save {} object because it's not trackable".format(type(conv_model)))# Now define a trackable object by inheriting from the tf.Module classclass MyModule(tf.Module):    @tf.function    def __call__(self, x): return conv_model(x)# Instantiate the trackable object, and call once to trace-compile a graphmodule_func = MyModule()module_func(input_tensor)tf.saved_model.save(module_func, export_dir)# Restore the model and verify that the outputs are consistentrestored_model = tf.saved_model.load(export_dir)restored_output_tensor = restored_model(input_tensor)print("Restored model outputs:", restored_output_tensor, sep="n")if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):    print("Outputs are consistent :)")else: print("Outputs are NOT consistent :(")

控制台输出:

Original model outputs:tf.Tensor([[[[-2.3629642   1.2904963 ]   [-2.3629642   1.2904963 ]   [-0.02110204  1.3400152 ]]  [[-2.3629642   1.2904963 ]   [-2.3629642   1.2904963 ]   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)Can't save <class 'tensorflow.python.eager.def_function.Function'> objectbecause it's not trackableRestored model outputs:tf.Tensor([[[[-2.3629642   1.2904963 ]   [-2.3629642   1.2904963 ]   [-0.02110204  1.3400152 ]]  [[-2.3629642   1.2904963 ]   [-2.3629642   1.2904963 ]   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)Outputs are consistent :)

因此,您应该尝试按以下方式修改代码:

svmod = tf.saved_model.load(fn) #version 1svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])class Exportable(tf.Module):    @tf.function    def __call__(self, model_inputs): return svmod2(model_inputs)svmod2_export = Exportable()svmod2_export(typical_input)    # call once with typical input to trace-compiletf.saved_model.save(svmod2_export, '/tmp/saved_model/')

如果您不想继承自

tf.Module
,则可以替换实例代码,实例化一个
tf.Module
对象并添加
tf.function
方法/可调用属性,如下所示:

to_export = tf.Module()to_export.call = tf.function(conv_model)to_export.call(input_tensor)tf.saved_model.save(to_export, export_dir)restored_module = tf.saved_model.load(export_dir)restored_func = restored_module.call


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/646268.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号