Without further ado, let's get straight to the code!
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Supports both multi-class and binary classification
class FocalLoss(nn.Module):
    """
    An implementation of Focal Loss with label-smoothing support, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002).
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param num_class: number of classes
    :param alpha: (tensor) scalar weighting factor for this criterion
    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified
        examples (p > 0.5), putting more focus on hard, misclassified examples
    :param smooth: (float, double) smoothing value applied to the one-hot targets
    :param balance_index: (int) index of the class to balance; must be given when alpha is a float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        if self.alpha is None:
            # No class weighting: uniform alpha across classes
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            # Per-class weights, normalized to sum to 1
            assert len(self.alpha) == self.num_class
            self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            # Single scalar: give `alpha` to the balanced class and (1 - alpha) to the rest
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, input, target):
        logit = F.softmax(input, dim=1)

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m = d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = target.view(-1, 1)

        # N = input.size(0)
        # alpha = torch.ones(N, self.num_class)
        # alpha = alpha * (1 - self.alpha)
        # alpha = alpha.scatter_(1, target.long(), self.alpha)
        epsilon = 1e-10
        alpha = self.alpha
        if alpha.device != input.device:
            alpha = alpha.to(input.device)

        idx = target.cpu().long()
        # One-hot encode the targets
        one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            # Label smoothing via clamping the one-hot targets
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth, 1.0 - self.smooth)

        # pt is the predicted probability of the true class
        pt = (one_hot_key * logit).sum(1) + epsilon
        logpt = pt.log()

        gamma = self.gamma
        # Gather the per-sample alpha and flatten to shape (N,) so it broadcasts
        # correctly against pt and logpt (indexing with the raw (N,1) idx would
        # yield an (N,1,1) tensor and silently broadcast to the wrong shape)
        alpha = alpha[idx.view(-1).to(alpha.device)].view(-1)

        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
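For reference, here is a minimal usage sketch for the multi-class case. The batch size, class count, and target values below are made up purely for illustration:

# Hypothetical example: batch of 4 samples, 5 classes
criterion = FocalLoss(num_class=5, alpha=0.25, balance_index=2, gamma=2)
logits = torch.randn(4, 5, requires_grad=True)  # raw, unnormalized scores
targets = torch.LongTensor([0, 2, 4, 1])        # class indices
loss = criterion(logits, targets)
loss.backward()  # behaves like any other nn.Module loss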
class BCEFocalLoss(torch.nn.Module):
    """
    Focal loss for binary classification, with a fixed alpha.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        # _input holds raw logits; convert to probabilities
        pt = torch.sigmoid(_input)
        alpha = self.alpha
        # Binary focal loss: positives weighted by alpha, negatives by
        # (1 - alpha), each modulated by the focusing term
        loss = -alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
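And a minimal sketch for the binary case, again with made-up shapes. Note that BCEFocalLoss expects raw logits, since it applies the sigmoid internally:

# Hypothetical example: batch of 8 binary samples
criterion = BCEFocalLoss(gamma=2, alpha=0.25)
logits = torch.randn(8, requires_grad=True)  # raw scores, sigmoid applied inside
targets = torch.empty(8).random_(2)          # 0/1 labels as floats
loss = criterion(logits, targets)
loss.backward()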
That is everything in this example of implementing focal loss for multi-class and binary classification in PyTorch. I hope it serves as a useful reference, and I hope you will continue to support 服务器之家.
Original article: https://blog.csdn.net/qq_33278884/article/details/91572173