设计神经网络的一般步骤:
1. 设计框架
2. 设计骨干网络
Unet网络设计的步骤:
1. 设计Unet网络工厂模式
2. 设计编解码结构
3. 设计卷积模块
4. 实例化Unet模型(工厂函数)
Unet网络最重要的特征:
1. 编解码结构。
2. 解码结构比FCN更加完善:通过跳跃连接(skip connection),把编码器各层的特征与上采样后的解码器特征在通道维度上拼接。
3. 本质是一个框架,编码部分可以使用很多图像分类网络。
示例代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
import torch
import torch.nn as nn
class Unet(nn.Module):
    """Generic U-Net wrapper: encoder -> optional bridge -> decoder.

    Args:
        Encoder: Either a ready-made encoder ``nn.Module`` or a sequence of
            blocks. A sequence is wrapped in the module-level ``Encoder``
            class. The encoder's forward pass must return a sequence whose
            first item is the deepest feature map and whose remaining items
            are the skip tensors.
        Decoder: Either a ready-made decoder ``nn.Module`` or a sequence of
            blocks (wrapped in the module-level ``Decoder`` class). The
            decoder is called with keyword arguments ``skips`` and ``x``.
        bridge: Optional module applied to the deepest encoder output before
            decoding. ``None`` (the default) skips this step.
    """

    def __init__(self, Encoder, Decoder, bridge=None):
        super().__init__()
        self.encoder = self._as_module(Encoder, "Encoder")
        self.decoder = self._as_module(Decoder, "Decoder")
        self.bridge = bridge

    @staticmethod
    def _as_module(stage, wrapper_name):
        """Return *stage* unchanged if it is a callable module, else wrap it."""
        # nn.ModuleList has no forward(), so treat it like a plain block list.
        if isinstance(stage, nn.Module) and not isinstance(stage, nn.ModuleList):
            return stage
        # The constructor parameters shadow the module-level classes of the
        # same name, so the wrapper class is looked up in module globals.
        return globals()[wrapper_name](stage)

    def forward(self, x):
        res = self.encoder(x)
        # res is a sequence: deepest output first, then the skip tensors.
        out, skip = res[0], list(res[1:])
        if self.bridge is not None:
            out = self.bridge(out)
        # Keyword call keeps the pairing explicit regardless of positional order.
        return self.decoder(skips=skip, x=out)
# Encoder module
class Encoder(nn.Module):
    """Contracting path built from an ordered list of blocks.

    Every block except the last contributes its output as a skip tensor;
    the last block produces the deepest feature map.

    Args:
        blocks: Non-empty sequence of ``nn.Module`` stages, applied in order.
    """

    def __init__(self, blocks):
        super().__init__()
        # Guard against an empty block list, which would leave nothing to run.
        assert len(blocks) > 0, "Encoder needs at least one block"
        # nn.ModuleList registers the blocks' parameters with this module;
        # it has no forward() of its own.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``[deepest_output, skip_0, skip_1, ...]`` (skips shallow to deep)."""
        *stages, last = self.blocks
        skip = []
        for stage in stages:
            x = stage(x)
            skip.append(x)
        # Deepest output goes first, followed by the skips in forward order.
        return [last(x)] + skip
# Decoder module
class Decoder(nn.Module):
    """Expanding path built from an ordered list of blocks.

    ``blocks[0]`` is applied to the incoming feature map on its own. Each
    later block receives the channel-wise concatenation of one skip tensor
    and the running feature map, after both are centre-cropped to a common
    spatial size.

    Args:
        blocks: Non-empty sequence of ``nn.Module`` stages, applied in order.
    """

    def __init__(self, blocks):
        super().__init__()
        assert len(blocks) > 0, "Decoder needs at least one block"
        self.blocks = nn.ModuleList(blocks)

    def center_crop(self, skips, x):
        """Centre-crop two 4-D tensors to their common height and width.

        Args:
            skips: A single skip tensor indexed as (N, C, H, W).
            x: The feature map it will be concatenated with.

        Returns:
            Tuple ``(cropped_skip, cropped_x)`` with equal H and W.
        """
        _, _, height1, width1 = skips.shape
        _, _, height2, width2 = x.shape
        # Crop both to the smaller extent so the sizes agree when concatenating.
        ht, wt = min(height1, height2), min(width1, width2)
        dh1, dw1 = (height1 - ht) // 2, (width1 - wt) // 2
        dh2, dw2 = (height2 - ht) // 2, (width2 - wt) // 2
        return (
            skips[:, :, dh1:dh1 + ht, dw1:dw1 + wt],
            x[:, :, dh2:dh2 + ht, dw2:dw2 + wt],
        )

    # Backward-compatible alias for the original (misspelled) method name.
    ceter_crop = center_crop

    def forward(self, skips, x, reverse_skips=True):
        """Decode *x* using the skip tensors.

        Args:
            skips: Sequence of ``len(self.blocks) - 1`` skip tensors.
            x: Deepest feature map.
            reverse_skips: When true (default), *skips* is reversed first so
                that the last entry is consumed first.

        Returns:
            The output of the final block.
        """
        assert len(skips) == len(self.blocks) - 1, (
            "expected one skip tensor per decoder block after the first"
        )
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for block, skip in zip(self.blocks[1:], skips):
            skip, x = self.center_crop(skip, x)
            # Concatenate along the channel dimension, skip features first.
            x = torch.cat([skip, x], dim=1)
            x = block(x)
        return x
# A double-convolution block
def unet_convs(in_channels, out_channels, padding=0):
    """Build two successive 3x3 Conv -> BatchNorm -> ReLU layers.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        padding: Zero-padding of each convolution. With the default ``0``
            each convolution trims 2 pixels from height and width.

    Returns:
        An ``nn.Sequential`` (which, unlike ``nn.ModuleList``, provides
        ``forward``).
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        # The second convolution consumes the first one's output channels.
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
# Factory that instantiates the U-Net model
def unet(in_channels, out_channels):
    """Assemble a five-level U-Net from ``unet_convs`` blocks.

    The encoder list holds four convolution stages (64, 128, 256 and 512
    channels) plus a final pooling stage, so four skip tensors are available
    to the five decoder stages. The convolutions are unpadded, so skip
    features are spatially larger than the upsampled features; the decoder
    is expected to crop them before concatenation.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map, produced by a final 1x1
            convolution.

    Returns:
        A ``Unet`` built from the encoder blocks, decoder blocks and bridge.
    """

    def downsample():
        return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

    def upsample(channels_in, channels_out):
        return nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)

    encoder_blocks = [
        unet_convs(in_channels, 64),
        nn.Sequential(downsample(), unet_convs(64, 128)),
        nn.Sequential(downsample(), unet_convs(128, 256)),
        nn.Sequential(downsample(), unet_convs(256, 512)),
        # Last encoder stage: its output feeds the bridge and is not a skip.
        downsample(),
    ]
    bridge = unet_convs(512, 1024)
    decoder_blocks = [
        upsample(1024, 512),
        # Each later stage sees skip + upsampled channels (hence the doubling).
        nn.Sequential(unet_convs(1024, 512), upsample(512, 256)),
        nn.Sequential(unet_convs(512, 256), upsample(256, 128)),
        nn.Sequential(unet_convs(256, 128), upsample(128, 64)),
        nn.Sequential(unet_convs(128, 64), nn.Conv2d(64, out_channels, kernel_size=1)),
    ]
    return Unet(encoder_blocks, decoder_blocks, bridge)
|
补充知识:PyTorch搭建U-Net网络
U-Net: Convolutional Networks for Biomedical Image Segmentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary
class DoubleConv(nn.Module):
    """Two successive 3x3 Conv -> BatchNorm -> ReLU layers.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by both convolutions.
        padding: Zero-padding of each convolution. The default ``0`` keeps
            the original unpadded behaviour, where each convolution trims
            2 pixels from height and width.
    """

    def __init__(self, in_ch, out_ch, padding=0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        """Apply both convolution stages to *input*."""
        return self.conv(input)
class Unet(nn.Module):
    """U-Net with four pooling stages and four transposed-conv upsampling stages.

    Channel widths are 64, 128, 256, 512 and 1024. Each skip tensor is
    centre-cropped to the size of the upsampled feature map it is
    concatenated with, so the crop follows the input size instead of being
    fixed for one resolution.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Transposed convolution; plain upsampling could be used instead.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _center_crop(feature, reference):
        """Return the central window of *feature* with *reference*'s height and width."""
        height, width = reference.shape[-2:]
        top = (feature.shape[-2] - height) // 2
        left = (feature.shape[-1] - width) // 2
        return feature[..., top:top + height, left:left + width]

    def forward(self, x):
        """Return the sigmoid of the final 1x1 convolution's output."""
        # Contracting path.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))

        # Expanding path: upsample, crop the matching skip, concatenate, convolve.
        up_6 = self.up6(c5)
        c6 = self.conv6(torch.cat([up_6, self._center_crop(c4, up_6)], dim=1))
        up_7 = self.up7(c6)
        c7 = self.conv7(torch.cat([up_7, self._center_crop(c3, up_7)], dim=1))
        up_8 = self.up8(c7)
        c8 = self.conv8(torch.cat([up_8, self._center_crop(c2, up_8)], dim=1))
        up_9 = self.up9(c8)
        c9 = self.conv9(torch.cat([up_9, self._center_crop(c1, up_9)], dim=1))

        return torch.sigmoid(self.conv10(c9))
# Smoke test: print a layer summary and the output size for a 572x572 input.
if __name__ == "__main__":
    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    # NOTE(review): depending on the installed torchsummary version, summary()
    # may build its probe input on CUDA when a GPU is available, while the
    # model here stays on the CPU. Confirm on GPU machines.
    summary(model, (1, 572, 572))
    # Only the output shape is inspected, so no autograd graph is needed.
    with torch.no_grad():
        output = model(test_input)
    print(output.size())
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Layer ( type ) Output Shape Param #
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
Conv2d - 1 [ - 1 , 64 , 570 , 570 ] 640
BatchNorm2d - 2 [ - 1 , 64 , 570 , 570 ] 128
ReLU - 3 [ - 1 , 64 , 570 , 570 ] 0
Conv2d - 4 [ - 1 , 64 , 568 , 568 ] 36 , 928
BatchNorm2d - 5 [ - 1 , 64 , 568 , 568 ] 128
ReLU - 6 [ - 1 , 64 , 568 , 568 ] 0
DoubleConv - 7 [ - 1 , 64 , 568 , 568 ] 0
MaxPool2d - 8 [ - 1 , 64 , 284 , 284 ] 0
Conv2d - 9 [ - 1 , 128 , 282 , 282 ] 73 , 856
BatchNorm2d - 10 [ - 1 , 128 , 282 , 282 ] 256
ReLU - 11 [ - 1 , 128 , 282 , 282 ] 0
Conv2d - 12 [ - 1 , 128 , 280 , 280 ] 147 , 584
BatchNorm2d - 13 [ - 1 , 128 , 280 , 280 ] 256
ReLU - 14 [ - 1 , 128 , 280 , 280 ] 0
DoubleConv - 15 [ - 1 , 128 , 280 , 280 ] 0
MaxPool2d - 16 [ - 1 , 128 , 140 , 140 ] 0
Conv2d - 17 [ - 1 , 256 , 138 , 138 ] 295 , 168
BatchNorm2d - 18 [ - 1 , 256 , 138 , 138 ] 512
ReLU - 19 [ - 1 , 256 , 138 , 138 ] 0
Conv2d - 20 [ - 1 , 256 , 136 , 136 ] 590 , 080
BatchNorm2d - 21 [ - 1 , 256 , 136 , 136 ] 512
ReLU - 22 [ - 1 , 256 , 136 , 136 ] 0
DoubleConv - 23 [ - 1 , 256 , 136 , 136 ] 0
MaxPool2d - 24 [ - 1 , 256 , 68 , 68 ] 0
Conv2d - 25 [ - 1 , 512 , 66 , 66 ] 1 , 180 , 160
BatchNorm2d - 26 [ - 1 , 512 , 66 , 66 ] 1 , 024
ReLU - 27 [ - 1 , 512 , 66 , 66 ] 0
Conv2d - 28 [ - 1 , 512 , 64 , 64 ] 2 , 359 , 808
BatchNorm2d - 29 [ - 1 , 512 , 64 , 64 ] 1 , 024
ReLU - 30 [ - 1 , 512 , 64 , 64 ] 0
DoubleConv - 31 [ - 1 , 512 , 64 , 64 ] 0
MaxPool2d - 32 [ - 1 , 512 , 32 , 32 ] 0
Conv2d - 33 [ - 1 , 1024 , 30 , 30 ] 4 , 719 , 616
BatchNorm2d - 34 [ - 1 , 1024 , 30 , 30 ] 2 , 048
ReLU - 35 [ - 1 , 1024 , 30 , 30 ] 0
Conv2d - 36 [ - 1 , 1024 , 28 , 28 ] 9 , 438 , 208
BatchNorm2d - 37 [ - 1 , 1024 , 28 , 28 ] 2 , 048
ReLU - 38 [ - 1 , 1024 , 28 , 28 ] 0
DoubleConv - 39 [ - 1 , 1024 , 28 , 28 ] 0
ConvTranspose2d - 40 [ - 1 , 512 , 56 , 56 ] 2 , 097 , 664
Conv2d - 41 [ - 1 , 512 , 54 , 54 ] 4 , 719 , 104
BatchNorm2d - 42 [ - 1 , 512 , 54 , 54 ] 1 , 024
ReLU - 43 [ - 1 , 512 , 54 , 54 ] 0
Conv2d - 44 [ - 1 , 512 , 52 , 52 ] 2 , 359 , 808
BatchNorm2d - 45 [ - 1 , 512 , 52 , 52 ] 1 , 024
ReLU - 46 [ - 1 , 512 , 52 , 52 ] 0
DoubleConv - 47 [ - 1 , 512 , 52 , 52 ] 0
ConvTranspose2d - 48 [ - 1 , 256 , 104 , 104 ] 524 , 544
Conv2d - 49 [ - 1 , 256 , 102 , 102 ] 1 , 179 , 904
BatchNorm2d - 50 [ - 1 , 256 , 102 , 102 ] 512
ReLU - 51 [ - 1 , 256 , 102 , 102 ] 0
Conv2d - 52 [ - 1 , 256 , 100 , 100 ] 590 , 080
BatchNorm2d - 53 [ - 1 , 256 , 100 , 100 ] 512
ReLU - 54 [ - 1 , 256 , 100 , 100 ] 0
DoubleConv - 55 [ - 1 , 256 , 100 , 100 ] 0
ConvTranspose2d - 56 [ - 1 , 128 , 200 , 200 ] 131 , 200
Conv2d - 57 [ - 1 , 128 , 198 , 198 ] 295 , 040
BatchNorm2d - 58 [ - 1 , 128 , 198 , 198 ] 256
ReLU - 59 [ - 1 , 128 , 198 , 198 ] 0
Conv2d - 60 [ - 1 , 128 , 196 , 196 ] 147 , 584
BatchNorm2d - 61 [ - 1 , 128 , 196 , 196 ] 256
ReLU - 62 [ - 1 , 128 , 196 , 196 ] 0
DoubleConv - 63 [ - 1 , 128 , 196 , 196 ] 0
ConvTranspose2d - 64 [ - 1 , 64 , 392 , 392 ] 32 , 832
Conv2d - 65 [ - 1 , 64 , 390 , 390 ] 73 , 792
BatchNorm2d - 66 [ - 1 , 64 , 390 , 390 ] 128
ReLU - 67 [ - 1 , 64 , 390 , 390 ] 0
Conv2d - 68 [ - 1 , 64 , 388 , 388 ] 36 , 928
BatchNorm2d - 69 [ - 1 , 64 , 388 , 388 ] 128
ReLU - 70 [ - 1 , 64 , 388 , 388 ] 0
DoubleConv - 71 [ - 1 , 64 , 388 , 388 ] 0
Conv2d - 72 [ - 1 , 2 , 388 , 388 ] 130
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
Total params: 31 , 042 , 434
Trainable params: 31 , 042 , 434
Non - trainable params: 0
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Input size (MB): 1.25
Forward / backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
torch.Size([ 1 , 2 , 388 , 388 ])
|
以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/weixin_38410551/article/details/104294545