栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

CrossEntropyLoss改进

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

CrossEntropyLoss改进

文章目录
  • 前言
  • 一、CrossEntropyLoss
  • 二、SmoothCrossEntropy
  • 三、Sparse Softmax


前言

CrossEntropyLoss 是分类任务中经常使用的损失函数,但是在某些情况下,其优化效果并不是很好,本文介绍了最近出现的对CrossEntropyLoss进行改进的新损失函数

一、CrossEntropyLoss

公式:

上图是pytorch版实现的CrossEntropyLoss,可以看出其主要作用是优化了正例对应的logits(logits介绍见上一篇博文)并使其无限大与其他类别的logits,这种过强的要求可能使得模型难以训练至收敛,因而出现了LabelSM版本的CrossEntropy,以及Sparse Softmax

顺带提一句,pytorch版本的CrossEntropyLoss是对dim=1进行的计算,
因而我们需要把各个类别的logits放到dim=1上来
二、SmoothCrossEntropy

公式:
SmoothCrossEntropy对应的公式为:

优势:
当 label smoothing 的 loss 函数为 cross entropy 时,如果 loss 取得极值点,则正确类和错误类的 logit 会保持一个常数距离,且正确类和所有错误类的 logits 相差的常数是一样的,都是 log ⁡ ( K − ( K − 1 ) α α ) log(frac{K-(K-1)alpha}{alpha}) log(αK−(K−1)α​)
证明见:知乎

code:

class SmoothCrossEntropy(nn.Module):
    """
    loss = SmoothCrossEntropy()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    """
    def __init__(self, alpha=0.1):
        super(SmoothCrossEntropy, self).__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        num_classes = logits.shape[-1]
        alpha_div_k = self.alpha / num_classes
        target_probs = F.one_hot(labels, num_classes=num_classes).float() * 
            (1. - self.alpha) + alpha_div_k
        loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
        return loss.mean()

代码如下(示例):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import  ssl
ssl._create_default_https_context = ssl._create_unverified_context
三、Sparse Softmax

公式:

优势:
这是苏神在CAIL2020中提出的一个类别数过多的预测问题损失函数,我们只需要优化前topK项,使得 s t s_t st​ 大于topk即可,不必要大于最小的 log ⁡ ( n − 1 ) log(n-1) log(n−1)
,只需大于topk中最小的 l o g ( k ) log(k) log(k)即可,可以防止过度训练
证明
pytoch版本:

def Sparse_Softmax(predictions, token_type_id, input_ids, vocab_size):

    predictions = predictions[:, :-1].contiguous()
    target_mask = token_type_id[:, 1:].contiguous()
    """
       target_mask : 句子a部分和pad部分全为0, 而句子b部分为1
    """
    predictions = predictions.view(-1, vocab_size)
    labels = input_ids[:, 1:].contiguous()
    labels = labels.view(-1)
    target_mask = target_mask.view(-1).float()
    # 正loss
    pos_loss = predictions[list(range(predictions.shape[0])), labels]
    # 负loss
    y_pred = torch.topk(predictions, k=args.k_sparse)[0]
    neg_loss = torch.logsumexp(y_pred, dim=-1)

    loss = neg_loss - pos_loss
    return (loss * target_mask).sum() / target_mask.sum()  ## 通过mask 取消 pad 和句子a部分预测的影响

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/290740.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号