First, a note: this code is only meant to aid understanding. It does not implement the gradient-descent (training) step, and the weights are fixed to random defaults, which does not affect comprehension. The code implements the forward pass of an RNN using only the numpy library, so it cannot take advantage of GPU acceleration.
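The heart of an RNN cell is a single recurrence: the new hidden state is h_t = tanh(Wih · x_t + Whh · h_(t-1)). As a warm-up, here is a minimal one-layer, one-timestep sketch of that update (the shape values are illustrative assumptions chosen to match the demo below):

import numpy as np

hidden_size, feature, batch = 5, 1, 2
Wih = np.random.random((hidden_size, feature))      # input-to-hidden weights
Whh = np.random.random((hidden_size, hidden_size))  # hidden-to-hidden weights
x_t = np.random.random((feature, batch))            # input at one timestep
h_prev = np.zeros((hidden_size, batch))             # previous hidden state

h_t = np.tanh(Wih @ x_t + Whh @ h_prev)             # the RNN recurrence
print(h_t.shape)  # (5, 2), i.e. [hidden_size, batch]

The full class below simply repeats this update over every timestep and stacks it over num_layers layers, feeding each layer's hidden state into the layer above.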
import numpy as np


class Rnn():
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional  # declared but not used below

    def feed(self, x):
        '''
        :param x: [seq, batch_size, embedding]
        :return: out, hidden
        '''
        # x.shape      [seq, batch, feature]
        # hidden.shape [hidden_size, batch]
        # Whh0.shape [hidden_size, hidden_size]  Wih0.shape [hidden_size, feature]
        # Whh1.shape [hidden_size, hidden_size]  Wih1.shape [hidden_size, hidden_size]
        out = []
        x = np.array(x)
        hidden = [np.zeros((self.hidden_size, x.shape[1])) for _ in range(self.num_layers)]
        # Random weights stand in for trained parameters (no training step here).
        # Layer 0 consumes the input features; every deeper layer consumes the
        # hidden state of the layer below, hence the different first shape.
        Wih = [np.random.random((self.hidden_size, self.hidden_size)) for _ in range(1, self.num_layers)]
        Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
        Whh = [np.random.random((self.hidden_size, self.hidden_size)) for _ in range(self.num_layers)]
        time = x.shape[0]
        for t in range(time):
            # layer 0: combine the current input with its previous hidden state
            hidden[0] = np.tanh(np.dot(Wih[0], np.transpose(x[t, ...], (1, 0))) +
                                np.dot(Whh[0], hidden[0]))
            # deeper layers: combine the layer below with their own previous state
            for i in range(1, self.num_layers):
                hidden[i] = np.tanh(np.dot(Wih[i], hidden[i - 1]) +
                                    np.dot(Whh[i], hidden[i]))
            # the output at each timestep is the top layer's hidden state
            out.append(hidden[self.num_layers - 1])
        return np.array(out), np.array(hidden)


def sigmoid(x):
    # 1 / (1 + e^(-x)); not used by feed(), kept for experimentation
    return 1.0 / (1.0 + np.exp(-x))


if __name__ == '__main__':
    rnn = Rnn(1, 5, 4)
    input = np.random.random((6, 2, 1))  # [seq, batch, feature]
    out, h = rnn.feed(input)
    print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]}',
          'out.shape', out.shape, 'h.shape', h.shape)
    # print(sigmoid(np.random.random((2, 3))))
    #
    # element-wise multiplication
    # print(np.array([1, 2]) * np.array([2, 1]))
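Running the demo prints out.shape (6, 5, 2) and h.shape (4, 5, 2): out stacks the top layer's hidden state for all 6 timesteps, while h holds the final hidden state of each of the 4 layers. To sanity-check the layout against a reference implementation, the following sketch (which assumes PyTorch is installed; only the shapes are comparable, since the random weights differ) runs the same configuration through torch.nn.RNN:

import torch
import torch.nn as nn

rnn_torch = nn.RNN(input_size=1, hidden_size=5, num_layers=4)
x = torch.rand(6, 2, 1)        # [seq, batch, feature], batch_first=False
out_t, h_t = rnn_torch(x)
print(out_t.shape, h_t.shape)  # torch.Size([6, 2, 5]) torch.Size([4, 2, 5])

The numpy version keeps hidden states as [hidden_size, batch], so its out and h carry the same information as PyTorch's with the last two axes transposed.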
This concludes this article on implementing the fundamentals of an RNN with numpy. For more on numpy RNN implementations, please search 服务器之家's earlier articles or browse the related articles below, and we hope you will continue to support 服务器之家!
Original article: https://blog.csdn.net/qq_43056256/article/details/114272542