天啦噜!!我发现更新的pytorch已经有instance normalization了!!
不用自己折腾了!!
-2017.5.25
利用 nn.Module 里的 子类 _BatchNorm (在torch.nn.modules.batchnorm中定义),可以实现各种需求的normalize。
在docs里,可以看到,有3种normalization layer,但其实他们都是继承了_BatchNorm这个类的,所以我们看看BatchNorm2d,就可以对其他的方法举一反三啦~
先来看看文档
不清楚没关系,接下来用例子讲解:
创建一个BatchNorm2d的实例的方法如下
import torch.nn as nn
norm = nn.BatchNorm2d(fea_num, affine=False).cuda()
其中,fea_num 是拉出来的维度,就是说按照 fea_num 的维度,其他维度拉成一长条来normalize,fea_num对应input的第1个(维度从0开始计)维度, 所以两者的值应相等。.cuda()是把这个module放到gpu上。
在普通的batch normalize的情况下
input是(batchsize,channel,height,width)=(4,3,5,5)来看,fea_num对应channel。所以channel=0时,求一次mean,var,做一次normalize;channel=1时,求一次。。channel=2时,求一次。。
在训练中,还有两个可以学习的参数gamma & beta,所以在gamma & beta设定为可变参数的情况下,应该这样创建和使用batchnorm layer:
#input is cuda float Variable of batchsize x channel x height x width
#train state
norm = nn.BatchNorm2d(channel).cuda()#默认affine=True
input = norm(input)
注意:
- 在train之前正确的初始化可变参数
- 在test/eval 模式下,应该用.eval() 固定住可变参数。
- 一个input的测试例子:
import numpy as np
from torch.autograd import Variable
BS = 2
C = 3
H = 2
W = 2
input = np.arange(BS*C*H*W)
input.resize(BS, C, H, W)
input = Variable(torch.from_numpy(input).clone().float()).cuda().contiguous()
如果不需要可变参数 gamma & beta,那直接:
#input is cuda float Variable of batchsize x channel x height x width
norm = nn.BatchNorm2d(channel, affine=False).cuda()
input = norm(input)
其他情况的normalize,如instance normalize
input还是(batchsize,channel,height,width)=(4,3,5,5)假设我们想把batchsize这一个维度拉出来,对每一个instance(batchsize=0~3)看做(3,5,5)的3D tensor 求一次normalize,那怎么做呢?其实很简单,把input的第0维和第1维调换一下就好了。
#input is cuda float Variable of batchsize x channel x height x width
instanceNorm = nn.BatchNorm2d(BS, affine=False).cuda()
input = input .transpose(0,1).contiguous()
input = instanceNorm(input)
input = input .transpose(0,1).contiguous()
注意:
- affine参数看需求设定,注意事项同普通batch normalize情况
- 如果没使用.contiguous(),很有可能报错
RuntimeError: Assertion `THCTensor_(isContiguous)(state, t)' failed. at **/pytorch/torch/lib/THCUNN/generic/BatchNormalization.cu:20
总而言之,记得BatchNorm layer 的 fea_num的取值=input拉出来的那个维度的大小,且该维度应该是input的第1维,如果不是,用resize、transpose、unsqueeze啥的搞到是就好了