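By default, PyTorch's cross-entropy loss is called as loss(pred, target), where pred holds floating-point scores (logits) and target holds integer class indices. The official documentation [2] gives these examples; the second one uses class probabilities as the target: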
import torch
import torch.nn as nn

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
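To see how the two target forms above relate, here is a minimal sketch (assuming PyTorch >= 1.10; the logits and labels are arbitrary) that passes the same labels once as integer indices and once as one-hot float probabilities. The two losses should agree, since a one-hot distribution puts all of its weight on the true class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

loss = nn.CrossEntropyLoss()
torch.manual_seed(0)
logits = torch.randn(3, 5)
labels = torch.tensor([1, 0, 4])  # integer class indices, dtype torch.long

# The same labels written as one-hot class probabilities (float32)
one_hot = F.one_hot(labels, num_classes=5).float()

loss_index = loss(logits, labels)   # index-target form
loss_prob = loss(logits, one_hot)   # probability-target form
print(loss_index.item(), loss_prob.item())
```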
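However, when the target has to be floating-point (for example, soft labels), PyTorch releases before 1.10, which added the class-probability form shown above, could not compute it with nn.CrossEntropyLoss() directly. The loss function is therefore written by hand here.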
See also the label-smoothing loss [3].
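Written out, the soft-target cross-entropy that the function below computes is the following, with x the logits pred (shape N × C) and q the soft_targets (same shape). The symbols N, C, x and q are notation introduced here, not taken from the sources. With a one-hot q it reduces to the usual negative log-softmax of the true class.

```latex
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\sum_{c=1}^{C} -q_{n,c}\,\log\frac{\exp(x_{n,c})}{\sum_{j=1}^{C}\exp(x_{n,j})}
```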
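import torch
import torch.nn as nn

def cross_entropy(pred, soft_targets):
    logsoftmax = nn.LogSoftmax(dim=1)  # log-probabilities over the class dimension
    return torch.mean(torch.sum(-soft_targets * logsoftmax(pred), 1))  # per-sample CE, averaged over the batch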
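To sanity-check the hand-written function, the sketch below compares it with the built-in loss, assuming PyTorch >= 1.10 (which accepts probability targets and the label_smoothing argument in nn.CrossEntropyLoss). It tries random soft targets, then label-smoothed targets built with the standard construction of mixing a one-hot vector with a uniform distribution. The tensors and eps = 0.1 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy(pred, soft_targets):
    # Copied from the function above so this block runs on its own:
    # per-sample -sum_c q_c * log softmax(pred)_c, then mean over the batch
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(-soft_targets * logsoftmax(pred), 1))


torch.manual_seed(0)
num_classes = 5
pred = torch.randn(3, num_classes)

# 1) Arbitrary soft targets (each row sums to 1)
soft = torch.randn(3, num_classes).softmax(dim=1)
custom = cross_entropy(pred, soft)
builtin = nn.CrossEntropyLoss()(pred, soft)  # probability targets: PyTorch >= 1.10
print(custom.item(), builtin.item())

# 2) Label smoothing: (1 - eps) * one-hot + eps / C for every class
eps = 0.1
labels = torch.tensor([0, 2, 4])
smoothed = (1 - eps) * F.one_hot(labels, num_classes).float() + eps / num_classes
custom_ls = cross_entropy(pred, smoothed)
builtin_ls = nn.CrossEntropyLoss(label_smoothing=eps)(pred, labels)
print(custom_ls.item(), builtin_ls.item())
```

In both cases the two numbers should match up to floating-point rounding. On newer PyTorch versions, the built-in loss is therefore an alternative to the hand-written one.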
References:
[1] https://discuss.pytorch.org/t/how-should-i-implement-cross-entropy-loss-with-continuous-target-outputs/10720/18
[2] https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
[3] https://blog.csdn.net/weixin_39529413/article/details/123122330



