栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

损失函数总结

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

损失函数总结

cross entropy loss
具体交叉熵的理论网络上有很多,这里就看下pytorch内部的计算方式

import torch
import torch.nn as nn

# 网络的输出数据,数据取自深度之眼pytorch入门lesson-15人民币二分类
# BATCH改为5,CLASS = 2
out = torch.tensor(
    [[-0.0223, 0.2420],
     [0.1782, 0.6221],
     [0.0887, 0.4575],
     [0.3041, 0.3169],
     [0.1052, 0.0649]])
# 数据标签
label = torch.tensor([1, 0, 1, 0, 1])
print(f"网络输出out = {out}")

print("0 - 比较手动log, softmax和LogSoftmax")
sm = nn.Softmax(dim=1)
print(f"经过Softmax = {sm(out)}")
print(f"再做Log = {torch.log(sm(out))}")
lsm = nn.LogSoftmax(dim=1)
lsm_result = lsm(out)
print(f"直接使用LogSoftmax = {lsm_result}")
print("n")

print("1 - 使用NLLLoss")
loss = nn.NLLLoss()
print(f"NLLLoss = {loss(lsm_result, label)}")
print("n")

print("2 - 手工计算loss")
ce = [lsm_result[index,i.item()].item() for index,i in enumerate(label)]
print(ce)
ce_tensor = torch.tensor(ce)
print(ce_tensor)
print(f"手工计算Loss = {-ce_tensor.mean()}")
print("n")

print("3 - 使用CrossEntropyLoss")
loss = nn.CrossEntropyLoss()
print(f"使用CrossEntropyLoss = {loss(out, label)}")
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/754897.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号