栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

多标签分类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

多标签分类

https://www.jianshu.com/p/ac3bec3dde3e

多标签分类任务损失函数
在二分类、多分类任务中通常使用交叉熵损失函数,即Pytorch中的CrossEntorpy,但是在多标签分类任务中使用的是BCEWithLogitsLoss函数。

BCEWithLogitsLoss与CrossEntorpy的不同之处在于计算样本所属类别概率值时使用的计算函数不同:
1)CrossEntorpy使用softmax函数,即将模型输出值作为softmax函数的输入,进而计算样本属于每个类别的概率,softmax计算得到的类别概率值加和为1。
2)BCEWithLogitsLoss使用sigmoid函数,将模型输出值作为sigmoid函数的输入,计算得到的多个类别概率值加和不一定为1。

共同点是计算概率值后都继续计算预测概率值和真实标签之间的交叉熵作为最终的损失函数值。

为什么在多标签任务中使用BCEWithLogitsLoss(sigmoid)函数呢?个人理解如下:
1)二分类/多分类任务是在两个/多个类别中取出一个类别,并且各个类别之间是互斥的,因此要保证多个类别的概率值加和为1(在类别概率值加和为1的情况下,一个类别概率值增加时必然有其他类别概率值减小,体现了各个类别之间的互斥),并且最终取出概率值最大的类别。
2)多分类任务是在多个类别中取出一个或多个类别,各个类别之间不互斥,因此无需保证各类别概率加和为1,只需要计算样本属于每一个类别的概率,如果样本属于某一类别的概率高于阈值则代表样本属于该类别(此时类别概率值加和不一定为1),例如样本A经过BCEWithLogitsLoss函数计算后得到属于类别1、类别2、类别3和类别4的概率值分别为[0.6,0.7,0.3,0.4],阈值为0.5,则样本A同时属于类别1和类别2。

使用方法如下:

import torch
import torch.nn as nn

#创建输入
input = torch.tensor([[0.1,0.2,0.3],[0.4,0.5,0.6]])#共有两个输入样本
target = torch.tensor([[1,0,1],[0,1,1]])#每个样本的标签值都是多标签

#创建模型并计算
model = nn.Linear(input_dim = 3,hidden_dim=5)#此处随便写一个模型示意
model_out = model(input)

#计算损失函数值
loss = torch.nn.BCEWithLogitsLoss(model_out,target)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/665149.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号