栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

nn.BatchNorm2d——批量标准化操作解读

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

nn.BatchNorm2d——批量标准化操作解读

Xnew​ (1−momentum)×Xold​ momentum×Xt​其中 Xnew​是模型的新参数 Xold​是模型原来的参数 Xt​是当前观测值的参数
②采用和训练阶段相同的计算方法 即只计算当前输入数据的均值和方差

输入

num_features 输入图像的通道数量。eps 稳定系数 防止分母出现0。momentum 模型均值和方差更新时的参数 见上述公式。affine 代表gamma beta是否可学。如果设为True 代表两个参数是通过学习得到的 如果设为False 代表两个参数是固定值 默认情况下 gamma是1 beta是0。track_running_stats 代表训练阶段是否更新模型存储的均值和方差 即测试阶段的均值与方差的计算方法采用第一种方法还是第二种方法。如果设为True 则代表训练阶段每次迭代都会更新模型存储的均值和方差(计算全局数据) 测试过程中利用存储的均值和方差对各个通道进行标准化处理 如果设为False 则模型不会存储均值和方差 训练过程中也不会更新均值和方差的数据 测试过程中只计算当前输入图像的均值和方差数据(局部数据)。具体区别见代码案例。

注意

训练阶段的标准化过程中 均值和方差来源途径只有一种方式 即利用当前输入的数据进行计算。测试阶段的标准化过程中 均值和方差来源途径有两种方式 一是来源于全局的数据 即模型本身存储一组均值和方差数据 在训练过程中 不断更新它们 使其具有描述全局数据的统计特性 二是来源于当前的输入数据 即和训练阶段计算方法一样 但这样会在测试过程中带来统计特性偏移的弊端 一般track_running_stats设置为True 即采用第一种来源途径。换句话说 就是训练阶段和测试阶段所承载的任务不同 训练阶段主要是通过已知的数据去优化模型 而测试阶段主要是利用已知的模型去预测未知的数据。

用途

训练过程中遇到收敛速度很慢的问题时 可以通过引入BN层来加快网络模型的收敛速度遇到梯度消失或者梯度爆炸的问题时 可以考虑引入BN层来解决一般情况下 还可以通过引入BN层来加快网络的训练速度

批量标准化的具体原理请参考论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

代码案例

一般用法

import torch
from torch import nn
# 在(0-1)范围内随机生成数据
img torch.rand(2,2,2,3)
bn nn.BatchNorm2d(2)
img_2 bn(img)
print(img)
print(img_2)

输出

# 标准化前
tensor([[[[0.5330, 0.7753, 0.6192],
 [0.9190, 0.1657, 0.5841]],
 [[0.7766, 0.7864, 0.2004],
 [0.9379, 0.3253, 0.1964]]],
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267934.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号