栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch.nn.KLDivLoss()损失函数输出结果为负数

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch.nn.KLDivLoss()损失函数输出结果为负数


蘆 Author :Horizon Max

✨ 编程技巧篇:各种操作小结

 机器视觉篇:会变魔术 OpenCV

 机器学习篇:简单入门 PyTorch

 神经网络篇:经典网络模型

 算法篇:再忙也别忘了 LeetCode


import torch

a = torch.ones((3, 3, 8, 8), dtype=torch.float32) * 0.5
b = a + 0.01

criterion = torch.nn.KLDivLoss()
loss = criterion(a, b)

print(loss)

输出结果:

tensor(-0.5984)

以及出现 UserWarning :

UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.
  "reduction: 'mean' divides the total loss by both the batch size and the support size."

解决方案参考博客:UserWarning: reduction: ‘mean‘ divides the total loss by both the batch size and the support size.

出现负数的原因参考:https://discuss.pytorch.org/

因为在计算损失时把概率分布映射到 log 空间,所以给输入添加 log 即可解决:

import torch

a = torch.ones((3, 3, 8, 8), dtype=torch.float32) * 0.5
b = a + 0.01

criterion = torch.nn.KLDivLoss()
loss = criterion(a.log(), b)

print(loss)

输出结果:

tensor(0.0101)


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/829269.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号