栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

损失函数(交叉熵误差)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

损失函数(交叉熵误差)

损失函数

神经网络以某个指标为线索寻找最优权重参数。神经网络的学习中所用的指标称为损失函数 (loss function)。这个损失函数可以使用任意函数,但一般用均方误差和交叉熵误差等。

交叉熵误差

除了均方误差之外,交叉熵误差 (cross entropy error)也经常被用作损失函数。交叉熵误差如下式所示。

这里,log 表示以e为底数的自然对数(loge )。yk是神经网络的输出,tk 是正确解标签。并且,tk 中只有正确解标签的索引为 1,其他均为 0(one-hot 表示)。因此,式(4.2)实际上只计算对应正确解标签的输出的自然对数。

自然对数的图像如图 4-3 所示。

如图 4-3 所示,x 等于 1 时,y 为 0;随着 x 向 0 靠近,y 逐渐变小。所以,正确解标签对应的输出越大,式(4.2)的值越接近 0;

当输出为 1 时,交叉熵误差为 0。此外,如果正确解标签对应的输出较小,则式(4.2)的值较大。

代码实现
import numpy as np


def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))


if __name__ == '__main__':
    t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
    y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
    print(cross_entropy_error(np.array(y), np.array(t)))
    y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
    print(cross_entropy_error(np.array(y), np.array(t)))

运行结果:

0.510825457099338
2.302584092994546

这里,参数 y 和 t 是 NumPy 数组。函数内部在计算 np.log 时,加上了一个微小值 delta 。这是因为,当出现np.log(0) 时,np.log(0) 会变为负无限大的 -inf,这样一来就会导致后续计算无法进行。作为保护性对策,添加一个微小值可以防止负无限大的发生。

正确解标签的索引是“2”,与之对应的神经网络的输出是 0.6,则交叉熵误差是 -log 0.6 = 0.51;若“2”对应的输出是 0.1,则交叉熵误差为 -log 0.1 = 2.30。也就是说,交叉熵误差的值是由正确解标签所对应的输出结果决定的。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/846426.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号