focal loss论文笔记(附基于keras的多类别focal loss代码)

时间:2024-03-13 18:35:49

一.focal loss论文

二.focal loss提出的目的

  • 解决one-stage目标检测是场景下前景和背景极度不平衡的情况(1:1000)
  • 让模型在训练的时候更加关注hard examples(前景)。
  • 另外two-stage的检测器是用一下两个方法来解决类别不平衡问题的:
    • 提取候选框的过程实际上就消除了很多背景框,因为提取的候选框是大概率包括目标的
    • 在第二个阶段训练的时候,minibatch一般被认为的固定正负样本的比例,大致是1:3

三.focal loss原理

1.CE(cross entropy) 交叉熵

focal loss论文笔记(附基于keras的多类别focal loss代码)
focal loss论文笔记(附基于keras的多类别focal loss代码)

2.balanced CE

focal loss论文笔记(附基于keras的多类别focal loss代码)

3.focal loss

focal loss论文笔记(附基于keras的多类别focal loss代码)
- 当一个样本误分类后,pt接近0,1-pt接近1,则loss无影响;当pt接近与1,则1-pt接近与0,loss的权重变的很小,则该样本的loss对总的loss贡献就小了。
- lamda是一个超参数,反应了权值系数的影响程度,在作者的实验中lamda=2的结果是最好的。

4.focal loss with balanced weight

focal loss论文笔记(附基于keras的多类别focal loss代码)

5.RetinaNet模型框架部分

focal loss论文笔记(附基于keras的多类别focal loss代码)
- backbone部分采用基于resnet的FPN结构,P3到P7一共5层的金字塔结构
- anchor部分,对于密集的目标场景增加了更多尺度的anchor
- 分类子网络部分,每个金字塔level的子网络参数是共享的,一个子网络包含4个conv层,卷积核大小为3*3;分类子网络和回归子网络的结构相同,但是参数是分开的,不像RPN里面是共享的。
- 网络的初始化,在分类子网络的最后一层卷积部分,偏执初始化为-log((1-pi)/pi);其他的偏执初始化为0,权值初始化为高斯权值,delta取0.01

四.focal loss代码

  • 论文中理论叙述的场景是基于二分类,以下为基于多分类的focal loss,同时参考一篇博客,加入了另外的因子能够更好的防止过拟合
  • 代码见maozezhong/focal_loss_multi_class