1.介绍
当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?
1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的
2)forward是表示一个前向传播,构建网络层的先后运算步骤
3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?
当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。
在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
2.代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__( self , in_channels, mid_channels, out_channels):
super (Net, self ).__init__()
self .conv0 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, mid_channels, kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )),
torch.nn.LeakyReLU())
self .conv1 = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, out_channels * 2 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )))
def forward( self , x):
x = self .conv0(x)
x = self .conv1(x)
return x
class Net(nn.Module):
def __init__( self , in_channels, mid_channels, out_channels):
super (Net, self ).__init__()
self .conv0 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, mid_channels, kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )),
torch.nn.LeakyReLU())
self .conv1 = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, out_channels * 2 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )))
def __call__( self , x):
x = self .conv0(x)
x = self .conv1(x)
return x
|
补充:torch/nn目录结构以及__init__.py
torch/nn目录结构以及init.py
torch/nn目录结构
__init__.py:
1
2
3
4
5
6
7
8
9
10
|
from .modules import *
#nn.modules 导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py 定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init 导入init.py 参数初始化
from . import utils
#nn.utils 导入utils目录下内容 官网api下nn.utils下api
|
对于backends, functional.py, _functions 需要在代码前重新Import
例如我们常用的
import torch.nn.functional as F 就是导入了functional.py
backends和_functions是functional.py实现各种函数时所用到的。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。如有错误或未考虑完全的地方,望不吝赐教。
原文链接:https://blog.csdn.net/u013289254/article/details/103826591