分享好友 人工智能首页 频道列表

Focal Loss 的Pytorch 实现以及实验

pytorch教程  2023-03-08 14:037770

Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。

如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有相对较小的loss回应。这样就是对简单样本的一种decay。其中alpha 是对每个类别在训练数据中的频率有关, 但是下面的实现我们是基于alpha=1进行实验的。

Focal Loss 的Pytorch 实现以及实验

标准的Cross Entropy 为:

Focal Loss 的Pytorch 实现以及实验

Focal Loss 为:

Focal Loss 的Pytorch 实现以及实验

Focal Loss 的Pytorch 实现以及实验

其中 Focal Loss 的Pytorch 实现以及实验

以上公式为下面实现代码的基础。

 

采用基于pytorch 的yolo2 在VOC的上的实验结果如下:

 

Focal Loss 的Pytorch 实现以及实验

在单纯的替换了CrossEntropyLoss之后就有1个点左右的提升。效果还是比较显著的。本实验中采用的是darknet19, 要是采用更大的网络就可能会有更好的性能提升。这个实验结果已经能很好的说明的Focal Loss 的对于检测的价值了。

 

一点没做的但是可能会提升性能:

1. 采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升

 

 

本文实验中采用的Focal Loss 代码如下。

关于Focal Loss 的数学推倒在文章:Focal Loss 的前向与后向公式推导

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

 

查看更多关于【pytorch教程】的文章

展开全文
相关推荐
反对 0
举报 0
图文资讯
热门推荐
优选好物
更多热点专题
更多推荐文章
Pytorch-基础入门之ANN pytorch零基础入门
在这部分中来介绍下ANN的Pytorch,这里的ANN具有三个隐含层。这一块的话与上一篇逻辑斯蒂回归使用的是相同的数据集MNIST。第一部分:构造模型# Import Librariesimport torchimport torch.nn as nnfrom torch.autograd import Variable# Create ANN Modelclas

0评论2023-03-08379

解说pytorch中的model=model.to(device) pytorch基础教程
这篇文章主要介绍了pytorch中的model=model.to(device)使用说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教这代表将模型加载到指定设备上。其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("c

0评论2023-02-09935

Faster-RCNN Pytorch实现的minibatch包装
实际上faster-rcnn对于输入的图片是有resize操作的,在resize的图片基础上提取feature map,而后generate一定数量的RoI。我想首先去掉这个resize的操作,对每张图都是在原始图片基础上进行识别,所以要找到它到底在哪里resize了图片。直接搜 grep 'resize' ./

0评论2023-02-09876

pytorch Gradient Clipping
梯度裁剪(Gradient Clipping)import torch.nn as nnoutputs = model(data)loss= loss_fn(outputs, target)optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)optimizer.step()nn.utils.clip_gra

0评论2023-02-09654

更多推荐