栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Focal loss

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Focal loss

focal loss原理:
控制正负样本权重
控制难易分类样本的权重
公式说明:
y就是实际标签
p就是预测值
CE(p,y)就是交叉熵

参数说明:
α就是你加的参数,也就是说,如果你把α设成0-0.5之间,你能够看到,其实是缩小了正样本的权重的,模型会重点去关注负样本
α如果是0.5-1之间,那也就意味着你增加了正样本的权重,模型会重点关注正样本


怎么设置:
正样本少,负样本多,α就在0.5-1之间设

控制难易分类样本的权重:
γ:调制因子
当pt的预测值比较低的时候,证明是模型难以识别这个样本,这时候(1-pt)比较大,通过γ就可以相对的增大权重
当pt的预测值比较高的时候,证明是模型容易识别这个样本,这时候(1-pt)比较小,通过γ进行γ次方,就可以相对的缩小权重
其实这里的增大和缩小都是相对来说的。因为(1-pt)本身就是小于1的,而γ大于1,所以(1-pt)的γ次方肯定是比原来的值小,但是难易分类的样本之间的差距其实是可以通过γ次方增大的。

所以整体的Focalloss就变成了下面这样。

pytorch代码实现:

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=.75, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        pos_weight = torch.FloatTensor([10.0]).to(device)
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, pos_weight=pos_weight, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580958.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号