栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

计算机视觉(十二):Tensorflow常用功能模块

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

计算机视觉(十二):Tensorflow常用功能模块

计算机视觉笔记总目录
1 fit的callbacks详解

回调是在训练过程的给定阶段应用的一组函数。可以使用回调来获取培训期间内部状态和模型统计信息的视图。您可以将回调列表(作为关键字参数callbacks)传递给或类的 fit() 方法。然后将在训练的每个阶段调用回调的相关方法。

  • 定制化保存模型
  • 保存events文件
1.1 ModelCheckpoint

from tensorflow.python.keras.callbacks import ModelCheckpoint

ModelCheckpoint(filepath, monitor='val_loss', save_best_only=False, save_weights_only=False, mode='auto', period=1)

  • Save the model after every epoch:每隔多少次迭代保存模型
  • filepath:保存模型字符串
    • 如果设置 weights.{epoch:02d}-{val_loss:.2f}.hdf5格式,将会每隔epoch number数量并且将验证集的损失保存在该位置
    • 如果设置weights.{epoch:02d}-{val_acc:.2f}.hdf5,将会按照val_acc的值进行保存模型
  • monitor:quantity to monitor.设置为’val_acc’或者’val_loss’
  • save_best_only:if save_best_only=True, 只保留比上次模型更好的结果
  • save_weights_only:if True, 只保存去那种(model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
  • mode:one of {auto, min, max}. 如果save_best_only=True, 对于val_acc, 要设置max, 对于val_loss要设置min
  • period: 迭代保存checkpoints的间隔
check = ModelCheckpoint('./ckpt/singlenn_{epoch:02d}-{val_acc:.2f}.h5',
                                monitor='val_acc',
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto',
                                period=1)

SingleNN.model.fit(self.train, self.train_label, epochs=5, callbacks=[check], validation_data=(x, y))

注意:使用ModelCheckpoint一定要在fit当中指定验证集才能使用,否则报错误。

1.2 Tensorboard

添加Tensorboard观察损失等情况

from tensorflow.python.keras.callbacks import TensorBoard

TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=False, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None, embeddings_data=None, update_freq='epoch')

  • log_dir:保存事件文件目录
  • write_graph=True:是否显示图结构
  • write_images=False:是否显示图片
  • write_grads=True:是否显示梯度histogram_freq 必须大于0
# 添加tensoboard观察
tensorboard = keras.callbacks.TensorBoard(log_dir='./graph', histogram_freq=1,
                                                  write_graph=True, write_images=True)

SingleNN.model.fit(self.train, self.train_label, epochs=5, callbacks=[tensorboard])

打开终端查看:

# 指定存在文件的目录,打开下面命令
tensoboard --logdir "./"

2 tf.data:数据集的构建与预处理
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/529335.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号