任务要求:
自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:
1
2
3
|
import torch
from torch.autograd import Function
from torch.autograd import Variable
|
定义二值化函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
class BinarizedF(Function):
def forward( self , input ):
self .save_for_backward( input )
a = torch.ones_like( input )
b = - torch.ones_like( input )
output = torch.where( input > = 0 ,a,b)
return output
def backward( self , output_grad):
input , = self .saved_tensors
input_abs = torch. abs ( input )
ones = torch.ones_like( input )
zeros = torch.zeros_like( input )
input_grad = torch.where(input_abs< = 1 ,ones, zeros)
return input_grad
|
定义一个module
1
2
3
4
5
6
7
8
|
class BinarizedModule(nn.Module):
def __init__( self ):
super (BinarizedModule, self ).__init__()
self .BF = BinarizedF()
def forward( self , input ):
print ( input .shape)
output = self .BF( input )
return output
|
进行测试
1
2
3
4
5
|
a = Variable(torch.randn( 4 , 480 , 640 ), requires_grad = True )
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print (a)
print (a.grad)
|
其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s
1
2
3
4
5
6
7
8
9
10
11
12
|
class BinarizedF(Function):
def forward( self , input ):
self .save_for_backward( input )
output = torch.ones_like( input )
output[ input < 0 ] = - 1
return output
def backward( self , output_grad):
input , = self .saved_tensors
input_grad = output_grad.clone()
input_abs = torch. abs ( input )
input_grad[input_abs> 1 ] = 0
return input_grad
|
以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/weixin_42696356/article/details/100899711