focal loss原理:
控制正负样本权重
控制难易分类样本的权重
公式说明:
y就是实际标签
p就是预测值
CE(p,y)就是交叉熵
参数说明:
α就是你加的参数,也就是说,如果你把α设成0-0.5之间,你能够看到,其实是缩小了正样本的权重的,模型会重点去关注负样本
α如果是0.5-1之间,那也就意味着你增加了正样本的权重,模型会重点关注正样本
怎么设置:
正样本少,负样本多,α就在0.5-1之间设
控制难易分类样本的权重:
γ:调制因子
当pt的预测值比较低的时候,证明是模型难以识别这个样本,这时候(1-pt)比较大,通过γ就可以相对的增大权重
当pt的预测值比较高的时候,证明是模型容易识别这个样本,这时候(1-pt)比较小,通过γ进行γ次方,就可以相对的缩小权重
其实这里的增大和缩小都是相对来说的。因为(1-pt)本身就是小于1的,而γ大于1,所以(1-pt)的γ次方肯定是比原来的值小,但是难易分类的样本之间的差距其实是可以通过γ次方增大的。
所以整体的Focalloss就变成了下面这样。
pytorch代码实现:
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=.75, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
pos_weight = torch.FloatTensor([10.0]).to(device)
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, pos_weight=pos_weight, reduction='none')
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
return F_loss.mean()



