相关算法
Python 代码参考 http://blog.csdn.net/zc02051126/article/details/9668439 (作少量修改与注释)
# coding: utf-8
"""Restricted Boltzmann machine trained with CD-1, demonstrated on MNIST.

Adapted (with minor modifications and comments) from
http://blog.csdn.net/zc02051126/article/details/9668439.
Written to run unchanged on Python 2 and Python 3.
"""
import sys

import numpy as np

try:
    import cPickle  # Python 2
except ImportError:  # Python 3 ships the same API as ``pickle``
    import pickle as cPickle

try:
    import matplotlib.pylab as plt
except ImportError:  # only the plotting demo in main() needs matplotlib
    plt = None


class RBM:
    """Restricted Boltzmann machine with stochastic 0/1 unit sampling.

    Data matrices use a column-per-sample layout: ``x`` has shape
    ``(n_visible, n_samples)``.

    Attributes:
        w: weights, shape (n_visible, n_hidden).
        v_bias: visible biases, shape (1, n_visible).
        h_bias: hidden biases, shape (1, n_hidden).
    """

    def __init__(self, n_visul, n_hidden, max_epoch=50, batch_size=110,
                 penalty=2e-4):
        """Create an RBM with small random weights and zero biases.

        Args:
            n_visul: number of visible units (input dimension).
            n_hidden: number of hidden units.
            max_epoch: number of training passes over the data.
            batch_size: maximum number of samples per mini-batch.
            penalty: L2 weight-decay coefficient applied to ``w``.
        """
        self.n_visible = n_visul
        self.n_hidden = n_hidden
        self.max_epoch = max_epoch
        self.batch_size = batch_size
        self.penalty = penalty
        # Uniform in [0, 0.1): small positive initial weights.
        self.w = np.random.random((self.n_visible, self.n_hidden)) * 0.1
        self.v_bias = np.zeros((1, self.n_visible))
        self.h_bias = np.zeros((1, self.n_hidden))

    def sigmoid(self, z):
        """Return the element-wise logistic function 1 / (1 + exp(-z))."""
        return 1.0 / (1.0 + np.exp(-z))

    def forward(self, vis):
        """Return hidden activation probabilities P(h = 1 | v).

        Args:
            vis: visible data, shape (n_visible, n_samples).

        Returns:
            Array of shape (n_samples, n_hidden).
        """
        return self.sigmoid(np.dot(vis.T, self.w) + self.h_bias)

    def backward(self, vis):
        """Return visible reconstruction probabilities P(v = 1 | h).

        Args:
            vis: hidden states, shape (n_samples, n_hidden). The parameter
                name is kept for backward compatibility; it holds HIDDEN units.

        Returns:
            Array of shape (n_samples, n_visible).
        """
        return self.sigmoid(np.dot(vis, self.w.T) + self.v_bias)

    @staticmethod
    def _sample(prob):
        """Draw a 0/1 float array that is 1 with element-wise probability ``prob``."""
        return (prob > np.random.random(prob.shape)).astype(float)

    def batch(self):
        """Randomly split the columns (samples) of ``self.x`` into mini-batches.

        Uses ceil(n_samples / batch_size) batches with balanced sizes, so no
        batch exceeds ``batch_size`` and none is empty.

        Returns:
            List of arrays, each of shape (n_visible, batch_len).

        Raises:
            ValueError: if ``self.x`` has no samples or ``batch_size`` < 1.
        """
        n_samples = self.x.shape[1]
        if n_samples == 0:
            raise ValueError("cannot split an empty data set into batches")
        if self.batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        # Ceiling division; never more batches than samples.
        num_batches = -(-n_samples // self.batch_size)
        # Deal samples round-robin into batches, then shuffle the assignment
        # so membership is random while sizes stay balanced.
        groups = np.arange(n_samples) % num_batches
        np.random.shuffle(groups)
        return [self.x[:, groups == i] for i in range(num_batches)]

    def rbmBB(self, x, eta=0.1, momentum=0.5):
        """Train with one-step contrastive divergence (CD-1) and momentum.

        Mini-batches are drawn once (via ``batch``) and reused every epoch.
        The summed reconstruction error is printed after each epoch.

        Args:
            x: training data, shape (n_visible, n_samples).
            eta: learning rate.
            momentum: momentum coefficient for the parameter increments.

        Side effects:
            Updates ``w``, ``h_bias`` and ``v_bias``, and sets ``x``,
            ``errors`` (summed reconstruction error per epoch),
            ``hiden_value`` (P(h|x), shape (n_samples, n_hidden); spelling kept
            for compatibility) and ``rebuild_value`` (P(v|h) for a sampled h,
            shape (n_samples, n_visible)).
        """
        self.x = x
        w_inc = np.zeros((self.n_visible, self.n_hidden))
        h_inc = np.zeros(self.n_hidden)
        v_inc = np.zeros(self.n_visible)
        batch_data = self.batch()
        errors = []
        for epoch in range(self.max_epoch):
            err_sum = 0.0
            for data in batch_data:
                num_cases = data.shape[1]
                # Positive phase: sample the hidden layer given the data.
                ph = self.forward(data)
                ph_states = self._sample(ph)
                # Negative phase: sample a reconstruction of the visible layer
                # from the hidden states, then recompute hidden probabilities.
                neg_data = self.backward(ph_states)
                neg_data_states = self._sample(neg_data).T
                nh = self.forward(neg_data_states)
                # CD-1 statistics: data correlations minus reconstruction ones.
                d_w = np.dot(data, ph) - np.dot(neg_data_states, nh)
                d_v = np.sum(data, axis=1) - np.sum(neg_data_states, axis=1)
                d_h = np.sum(ph, axis=0) - np.sum(nh, axis=0)
                # Momentum updates; weight decay applies to the weights only.
                w_inc = momentum * w_inc + eta * (
                    d_w / num_cases - self.penalty * self.w)
                h_inc = momentum * h_inc + eta * (d_h / num_cases)
                v_inc = momentum * v_inc + eta * (d_v / num_cases)
                self.w = self.w + w_inc
                self.h_bias = self.h_bias + h_inc
                self.v_bias = self.v_bias + v_inc
                err_sum += np.linalg.norm(data - neg_data.T)
            print('%s %s' % (epoch, err_sum))
            errors.append(err_sum)
        self.errors = errors
        self.hiden_value = self.forward(self.x)
        self.rebuild_value = self.backward(self._sample(self.hiden_value))

    def visualize(self, X):
        """Tile the columns of ``X`` into a single image for display.

        Each column of length D (must be a perfect square s*s) becomes an
        s x s tile. Tiles fill a ceil(sqrt(N)) x ceil(sqrt(N)) grid column by
        column, separated by one-pixel borders of value -1.

        Args:
            X: array of shape (D, N).

        Returns:
            Square array with side num * (s + 1) + 1, where num = ceil(sqrt(N)).

        Raises:
            ValueError: if D is not a perfect square.
        """
        D, N = X.shape
        s = int(np.sqrt(D))
        if s * s != D:
            raise ValueError("column length %d is not a perfect square" % D)
        num = int(np.ceil(np.sqrt(N)))
        side = num * (s + 1) + 1
        canvas = np.full((side, side), -1.0)
        for i in range(N):
            # Tile i goes to grid row i % num, grid column i // num.
            grid_col, grid_row = divmod(i, num)
            r0, c0 = grid_row * (s + 1), grid_col * (s + 1)
            # A row-major reshape equals reshape(s, s, order='F').T.
            canvas[r0:r0 + s, c0:c0 + s] = X[:, i].reshape(s, s)
        return canvas


def readData(path):
    """Read a whitespace-separated numeric text file.

    Tokens may be separated by any whitespace (spaces or tabs); blank lines
    are skipped.

    Returns:
        A list of rows, each a list of floats.
    """
    with open(path, 'r') as fh:
        return [[float(tok) for tok in line.split()]
                for line in fh if line.strip()]


def _load_mnist(path):
    """Unpickle the (training, validation, test) MNIST tuple from ``path``."""
    # Presumably a Python 2 pickle (the original loaded it via cPickle);
    # Python 3 needs encoding='latin1' to decode such numpy pickles.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    kwargs = {'encoding': 'latin1'} if sys.version_info[0] >= 3 else {}
    with open(path, 'rb') as f:
        return cPickle.load(f, **kwargs)


def _show(fig_num, image, title):
    """Draw ``image`` in matplotlib figure ``fig_num`` under ``title``."""
    fig = plt.figure(fig_num)
    ax = fig.add_subplot(111)
    ax.imshow(image)
    plt.title(title)


def main():
    """Train a 784-100 RBM on the first 5000 MNIST training images and plot."""
    if plt is None:
        raise SystemExit("matplotlib is required to run this demo")
    training_data = _load_mnist('mnist.pkl')[0]
    # Flatten each image to a 784-vector; columns of ``data`` are samples.
    data = np.array([np.reshape(x, 784) for x in training_data[0][:5000]]).T
    rbm = RBM(784, 100, max_epoch=50)
    rbm.rbmBB(data)

    _show(1, rbm.visualize(data), 'original data')                # (2060, 2060)
    _show(2, rbm.visualize(rbm.rebuild_value.T), 'rebuild data')  # (2060, 2060)
    _show(3, rbm.visualize(rbm.hiden_value.T), 'hidden data')     # (782, 782)
    _show(4, rbm.visualize(rbm.w), 'weight value(w)')             # (291, 291)
    plt.show()


if __name__ == '__main__':
    main()