LSTM 时间序列预测 matlab

时间:2022-04-20 17:19:23


由于参加了一个关于时间序列预测的小课题,而我平时习惯使用 MATLAB,网上这方面的 MATLAB 资源却比较少。

本文借鉴了 http://blog.csdn.net/u010540396/article/details/52797489 的内容,并对程序做了少量修改。


程序说明:DATA.mat 中保存一个名为 a 的一维时序数据,需存为列向量(每行一个时间点);

numdely 表示用前 numdely 个点预测当前点;cell_num 是隐含层单元(LSTM cell)的个数;cost_gate 是训练误差(所有训练样本误差的平方和)的阈值,误差低于该值时提前停止训练。

直接在命令行输入RunLstm(numdely,cell_num,cost_gate)即可。

 
function [r1, r2] = RunLstm(numdely, cell_num, cost_gate, max_iter, yita)
% RunLstm  Train a single-layer LSTM on the series in DATA.mat and predict its last point.
%
%   RunLstm(numdely, cell_num, cost_gate) windows the series with
%   LSTM_data_process, trains the network with LSTM_updata_weight until the
%   summed squared training error drops below cost_gate (or max_iter
%   iterations have run), plots the error curve, and displays the predicted
%   and the true value of the last point of the series.
%
%   [r1, r2] = RunLstm(...) additionally returns those two values.
%
%   Inputs:
%     numdely   - number of preceding points used to predict the next one
%     cell_num  - number of LSTM cells (hidden units)
%     cost_gate - early-stopping threshold on the summed squared training error
%     max_iter  - (optional) maximum number of training iterations, default 100
%     yita      - (optional) learning rate, default 0.01
%
%   Outputs:
%     r1 - predicted last value, rescaled by the norm of the test window
%     r2 - true last value of the series

if nargin < 4 || isempty(max_iter)
    max_iter = 100;
end
if nargin < 5 || isempty(yita)
    yita = 0.01;
end
if max_iter < 1
    error('RunLstm:badMaxIter', 'max_iter must be at least 1.');
end

sigmoid = @(z) 1 ./ (1 + exp(-z));

%% Load the windowed, normalized data
[train_data, test_data] = LSTM_data_process(numdely);
data_length = size(train_data, 1) - 1;   % last row of each window is the target
data_num = size(train_data, 2);

%% Network parameter initialization
% layer sizes
input_num = data_length;
output_num = 1;
% gate biases
bias_input_gate = rand(1, cell_num);
bias_forget_gate = rand(1, cell_num);
bias_output_gate = rand(1, cell_num);
% input / recurrent weights, scaled down by ab to keep initial activations small
ab = 20;
weight_input_x = rand(input_num, cell_num) / ab;
weight_input_h = rand(output_num, cell_num) / ab;
weight_inputgate_x = rand(input_num, cell_num) / ab;
weight_inputgate_c = rand(cell_num, cell_num) / ab;
weight_forgetgate_x = rand(input_num, cell_num) / ab;
weight_forgetgate_c = rand(cell_num, cell_num) / ab;
weight_outputgate_x = rand(input_num, cell_num) / ab;
weight_outputgate_c = rand(cell_num, cell_num) / ab;
% hidden -> output weights
weight_preh_h = rand(cell_num, output_num);
% network states, one column per time step
h_state = rand(output_num, data_num);
cell_state = rand(cell_num, data_num);

Error_Cost = zeros(1, max_iter);

%% Training
for iter = 1:max_iter
    % forward pass over the whole training sequence
    for m = 1:data_num
        x = train_data(1:input_num, m)';
        if m == 1
            % first step: no previous state, forget gate held at zero
            gate = tanh(x * weight_input_x);
            input_gate_input = x * weight_inputgate_x + bias_input_gate;
            output_gate_input = x * weight_outputgate_x + bias_output_gate;
            forget_gate_input = zeros(1, cell_num);
            input_gate = sigmoid(input_gate_input);
            output_gate = sigmoid(output_gate_input);
            forget_gate = zeros(1, cell_num);
            cell_state(:, m) = (input_gate .* gate)';
        else
            h_prev = h_state(:, m-1)';
            c_prev = cell_state(:, m-1)';
            gate = tanh(x * weight_input_x + h_prev * weight_input_h);
            input_gate_input = x * weight_inputgate_x + c_prev * weight_inputgate_c + bias_input_gate;
            forget_gate_input = x * weight_forgetgate_x + c_prev * weight_forgetgate_c + bias_forget_gate;
            output_gate_input = x * weight_outputgate_x + c_prev * weight_outputgate_c + bias_output_gate;
            input_gate = sigmoid(input_gate_input);
            forget_gate = sigmoid(forget_gate_input);
            output_gate = sigmoid(output_gate_input);
            cell_state(:, m) = (input_gate .* gate + c_prev .* forget_gate)';
        end
        pre_h_state = tanh(cell_state(:, m)') .* output_gate;
        h_state(:, m) = (pre_h_state * weight_preh_h)';
    end

    % per-step prediction errors and their summed square
    Error = h_state - train_data(end, :);
    Error_Cost(1, iter) = sum(Error(:) .^ 2);
    if Error_Cost(1, iter) < cost_gate
        iter
        break;
    end

    % The gate values passed below belong to the LAST step (data_num), so the
    % error of that same step is passed; handing over the whole Error row
    % would make the update read Error(1,1), the error of the FIRST step.
    [weight_input_x, ...
     weight_input_h, ...
     weight_inputgate_x, ...
     weight_inputgate_c, ...
     weight_forgetgate_x, ...
     weight_forgetgate_c, ...
     weight_outputgate_x, ...
     weight_outputgate_c, ...
     weight_preh_h] = LSTM_updata_weight(data_num, yita, Error(:, data_num), ...
                                         weight_input_x, ...
                                         weight_input_h, ...
                                         weight_inputgate_x, ...
                                         weight_inputgate_c, ...
                                         weight_forgetgate_x, ...
                                         weight_forgetgate_c, ...
                                         weight_outputgate_x, ...
                                         weight_outputgate_c, ...
                                         weight_preh_h, ...
                                         cell_state, h_state, ...
                                         input_gate, forget_gate, ...
                                         output_gate, gate, ...
                                         train_data, pre_h_state, ...
                                         input_gate_input, ...
                                         output_gate_input, ...
                                         forget_gate_input);
end

%% Error-cost curve
figure;
iters = 1:iter;
semilogy(iters, Error_Cost(iters), '*');
title('Error-Cost曲线图');

%% Test on the final window
% NOTE(review): the norm used here includes the target test_data(end), the
% same way LSTM_data_process normalizes training windows; in a real forecast
% that value is unknown, so rescaling by `total` leaks the answer.
total = sqrt(sum(test_data .^ 2));
test_final = test_data / total;
test_output = test_data(:, end);

% The test window is the one right after the last training window, so it is
% conditioned on the state of the last training step (data_num).
% NOTE(review): these states come from the last training forward pass, i.e.
% before the final weight update when training ran to max_iter.
x = test_final(1:input_num)';
h_prev = h_state(:, data_num)';
c_prev = cell_state(:, data_num)';
gate = tanh(x * weight_input_x + h_prev * weight_input_h);
input_gate = sigmoid(x * weight_inputgate_x + c_prev * weight_inputgate_c + bias_input_gate);
forget_gate = sigmoid(x * weight_forgetgate_x + c_prev * weight_forgetgate_c + bias_forget_gate);
output_gate = sigmoid(x * weight_outputgate_x + c_prev * weight_outputgate_c + bias_output_gate);
cell_state_test = (input_gate .* gate + c_prev .* forget_gate)';
pre_h_state = tanh(cell_state_test') .* output_gate;
h_state_test = (pre_h_state * weight_preh_h)' * total;

test_msg = sprintf('----Test result is %s----' ,num2str(h_state_test));
true_msg = sprintf('----True result is %s----' ,num2str(test_output(end)));
disp(test_msg);
disp(true_msg);

r1 = h_state_test;
r2 = test_output(end);

function [train_data,test_data]=LSTM_data_process(numdely)
% LSTM_data_process  Build normalized training windows and the final test window.
%
%   [train_data, test_data] = LSTM_data_process(numdely) loads variable `a`
%   from DATA.mat and slices it into overlapping windows of numdely+1
%   consecutive points. `a` may be stored as a row or a column vector.
%
%   Outputs:
%     train_data - (numdely+1) x numsample matrix, numsample = numel(a)-numdely-1.
%                  Column i holds a(i:i+numdely) scaled to unit Euclidean norm;
%                  rows 1..numdely are the inputs and the last row is the target.
%     test_data  - (numdely+1) x 1 column with the last numdely+1 points of `a`,
%                  i.e. the window right after the last training window.
%                  It is returned NOT normalized.
%
%   Errors if DATA.mat has no variable `a` or the series is shorter than numdely+2.

S = load('DATA.mat');
if ~isfield(S, 'a')
    error('LSTM_data_process:missingData', 'DATA.mat must contain a variable named ''a''.');
end
a = S.a(:);                      % accept row or column storage
numdata = numel(a);
numsample = numdata - numdely - 1;
if numsample < 1
    error('LSTM_data_process:tooShort', ...
          'Series has %d points; at least numdely+2 = %d are needed.', numdata, numdely + 2);
end

train_data = zeros(numdely + 1, numsample);
for i = 1:numsample
    train_data(:, i) = a(i:i+numdely);
end
test_data = a(numdata-numdely:numdata);

% Scale every training window to unit L2 norm.
train_data = bsxfun(@rdivide, train_data, sqrt(sum(train_data .^ 2, 1)));


function [   weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h ]=LSTM_updata_weight(n,yita,Error,...
weight_input_x, weight_input_h, weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h,...
cell_state,h_state,input_gate,forget_gate,output_gate,gate,train_data,pre_h_state,input_gate_input, output_gate_input,forget_gate_input)
% LSTM_updata_weight  One gradient-descent step on the LSTM weights for time step n.
%
%   The loss is the squared error of the network output at step n. Gradients
%   are taken through that step only: the previous states h_state(:,n-1) and
%   cell_state(:,n-1) are treated as constants (no back-propagation through
%   time), and the gate biases are not updated. At n == 1 only the weights
%   that act on the current input are updated (there is no previous state).
%   All gradients are computed from the weights passed in, so every weight is
%   updated simultaneously.
%
%   Inputs:
%     n            - time step (column of train_data / states) being differentiated
%     yita         - learning rate
%     Error        - prediction minus target: either the error of step n
%                    (output_num x 1) or the per-step errors of the whole
%                    sequence (one column per step), in which case column n is used
%     weight_*     - current weights; the updated ones are returned with the same shapes
%     cell_state, h_state - cell / output states, one column per step
%     input_gate, forget_gate, output_gate, gate - 1 x cell_num activations at step n
%     train_data   - window matrix; rows 1..end-1 are the network inputs
%     pre_h_state  - tanh(cell_state(:,n))' .* output_gate
%     input_gate_input, output_gate_input, forget_gate_input - 1 x cell_num
%                    pre-activations of the sigmoid gates at step n

input_num = size(train_data, 1) - 1;    % last row of train_data is the target

% Error of the step whose activations were passed in.
if size(Error, 2) > 1
    err = Error(:, n);
else
    err = Error(:, 1);
end

x = train_data(1:input_num, n);         % input column of step n
tanh_c = tanh(cell_state(:, n))';       % 1 x cell_num

% dLoss/d(pre_h_state): derivative of the squared error through the output layer
d_pre_h = 2 * (weight_preh_h * err)';
% dLoss/d(cell_state(:,n)) through pre_h_state = tanh(c) .* output_gate
d_cell = d_pre_h .* output_gate .* (1 - tanh_c .^ 2);

% sigmoid'(z) = exp(-z) .* sigmoid(z).^2
d_output_gate = d_pre_h .* tanh_c .* exp(-output_gate_input) .* output_gate .^ 2;
d_input_gate = d_cell .* gate .* exp(-input_gate_input) .* input_gate .^ 2;

% Pre-activation of the candidate gate (gate = tanh(temp)), computed once with
% the incoming weights so that weight_input_h sees the same value as weight_input_x.
if n ~= 1
    temp = x' * weight_input_x + h_state(:, n-1)' * weight_input_h;
else
    temp = x' * weight_input_x;
end
% tanh'(temp) = 1 - tanh(temp).^2
d_gate = d_cell .* input_gate .* (1 - tanh(temp) .^ 2);

% hidden -> output weights (applied last, after the other gradients used the old value)
delta_weight_preh_h = 2 * pre_h_state' * err';

% Each weight gradient is an outer product: (input feeding the weight) x (gate gradient).
weight_outputgate_x = weight_outputgate_x - yita * (x * d_output_gate);
weight_inputgate_x = weight_inputgate_x - yita * (x * d_input_gate);
weight_input_x = weight_input_x - yita * (x * d_gate);

if n ~= 1
    c_prev = cell_state(:, n-1);
    d_forget_gate = d_cell .* c_prev' .* exp(-forget_gate_input) .* forget_gate .^ 2;

    weight_forgetgate_x = weight_forgetgate_x - yita * (x * d_forget_gate);
    weight_inputgate_c = weight_inputgate_c - yita * (c_prev * d_input_gate);
    weight_forgetgate_c = weight_forgetgate_c - yita * (c_prev * d_forget_gate);
    weight_outputgate_c = weight_outputgate_c - yita * (c_prev * d_output_gate);
    weight_input_h = weight_input_h - yita * (h_state(:, n-1) * d_gate);
end

weight_preh_h = weight_preh_h - yita * delta_weight_preh_h;

end


---------------------------------------2017.08.03 UPDATE----------------------------------------

代码数据链接:

http://download.csdn.net/detail/u011060119/9919621