栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

使用自定义图层保存Keras模型

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用自定义图层保存Keras模型

修正数1是使用

Custom_Objects
同时
loading
Saved Model
,即,更换代码,

new_model = tf.keras.models.load_model('model.h5')

new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})

由于我们使用

Custom Layers
build
Model
之前
Saving
的话,我们应该使用
CustomObjects
,同时
Loading
它。

更正数字2是在自定义图层

**kwargs
__init__
功能中添加

def __init__(self, k, name=None, **kwargs):        super(CustomLayer, self).__init__(name=name)        self.k = k        super(CustomLayer, self).__init__(**kwargs)

完整的工作代码如下所示:

import tensorflow as tfclass CustomLayer(tf.keras.layers.Layer):    def __init__(self, k, name=None, **kwargs):        super(CustomLayer, self).__init__(name=name)        self.k = k        super(CustomLayer, self).__init__(**kwargs)    def get_config(self):        config = super(CustomLayer, self).get_config()        config.update({"k": self.k})        return config    def call(self, input):        return tf.multiply(input, 2)model = tf.keras.models.Sequential([    tf.keras.Input(name='input_layer', shape=(10,)),    CustomLayer(10, name='custom_layer'),    tf.keras.layers.Dense(1, activation='sigmoid', name='output_layer')])tf.keras.models.save_model(model, 'model.h5')new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})print(new_model.summary())

上面代码的输出如下所示:

WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.Model: "sequential_1"_________________________________________________________________Layer (type)      Output Shape   Param #   =================================================================custom_layer_1 (CustomLayer) (None, 10)     0         _________________________________________________________________output_layer (Dense)         (None, 1)      11        =================================================================Total params: 11Trainable params: 11Non-trainable params: 0

希望这可以帮助。学习愉快!



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/659851.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号