BP网络实现分类问题

时间:2015-12-14 07:58:55
【文件属性】:

文件名称:       BP网络实现分类问题

文件大小:223KB

文件格式:DOCX

更新时间:2015-12-14 07:58:55

BP解决分类

function main() InDim=2; % 样本输入维数 OutDim=3; % 样本输出维数 % figure % colordef(gcf,'white') % echo off % clc % axis([-2,2,-2,2]) % axis on % grid % xlabel('Input x'); % ylabel('Input y'); %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % line([-1 1],[1 1]) % line([1 -1],[1 0]) % line([-1 -1],[0 1]) % line([-1 1],[-0.5 -0.5]) % line([-1 1],[-1.5 -1.5]) % line([1 1],[-0.5 -1.5]) % line([-1 -1],[-0.5 -1.5]) %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % hold on % sj=plot([-1 1],[1 1],[1 -1],[1 0],[-1 -1],[0 1]); % hold on % set(sj,'Color','r','LineWidth',4); % js=plot([-1 1],[-0.5 -0.5],'b',[-1 1],[-1.5 -1.5],'b',[1 1],... % [-0.5 -1.5],'b',[-1 -1],[-0.5 -1.5],'b'); % hold on % set(js,'Color','b','LineWidth',4); %hold off figure colordef(gcf,'white') echo off clc axis([-2,2,-2,2]) axis on grid xlabel('Input x'); ylabel('Input y'); hold on sj=plot([-1 1],[1 1],[1 -1],[1 0],[-1 -1],[0 1]); hold on js=plot([-1 1],[-0.5 -0.5],'b',[-1 1],[-1.5 -1.5],'b',[1 1],... [-0.5 -1.5],'b',[-1 -1],[-0.5 -1.5],'b'); hold on set(sj,'Color','r','LineWidth',4); set(js,'Color','b','LineWidth',4); hold on SamNum=400; % 训练样本数 rand('state', sum(100*clock)) SamIn=(rand(2,SamNum)-0.5)*4; % 产生随机样本输入 % 根据目标函数获得训练样本输入输出,并绘制样本 SamOut=[]; for i=1:SamNum Sam=SamIn(:,i); x=Sam(1,1); y=Sam(2,1); if((x>-1)&(x<1))==1 if((y>x/2+1/2)&(y<1))==1 plot(x,y,'k+') class=[0 1 0]'; elseif((y<-0.5)&(y>-1.5))==1 plot(x,y,'ks') class=[0 0 1]'; else plot(x,y,'ko') class=[1 0 0]'; end else plot(x,y,'ko') class=[1 0 0]'; end SamOut=[SamOut class]; end HiddenUnitNum=10; % 隐节点数 MaxEpochs=10000; % 最大训练次数 lr=0.1; % 学习率 E0=0.1; % 目标误差 W1=0.2*rand(HiddenUnitNum,InDim)-0.1; % 输入层到隐层的初始权值 B1=0.2*rand(HiddenUnitNum,1)-0.1; % 隐节点初始偏移 W2=0.2*rand(OutDim,HiddenUnitNum)-0.1; % 隐层到输出层的初始权值 B2=0.2*rand(OutDim,1)-0.1; % 输出层初始偏移 W1Ex=[W1 B1]; % 输入层到隐层的初始权值扩展, 10*3 W2Ex=[W2 B2]; % 隐层到输出层的初始权值, 3*11 SamInEx=[SamIn' ones(SamNum,1)]'; % 样本输入扩展, 3*200 ErrHistory=[]; % 用于记录每次权值调整后的训练误差 for i=1:MaxEpochs % 正向传播计算网络输出 HiddenOut=logsig(W1Ex*SamInEx); HiddenOutEx=[HiddenOut' ones(SamNum, 1)]'; NetworkOut=logsig(W2Ex*HiddenOutEx); % 停止学习判断 Error=SamOut-NetworkOut; SSE=sumsqr(Error); fprintf('Times: %7.0f',i); fprintf(' SSE: .4f\n\n',SSE); % 记录每次权值调整后的训练误差 ErrHistory=[ErrHistory SSE]; if SSE-1)&(x<1))==1 if((y>x/2+1/2)&(y<1))==1 TestTargetOut=[TestTargetOut 2]; elseif((y<-0.5)&(y>-1.5))==1 TestTargetOut=[TestTargetOut 3]; else TestTargetOut=[TestTargetOut 1]; end else TestTargetOut=[TestTargetOut 1]; end end %显示计算结果 NNC1Flag=abs(NNClass-1)<0.1; NNC2Flag=abs(NNClass-2)<0.1; NNC3Flag=abs(NNClass-3)<0.1; TargetC1Flag=abs(TestTargetOut-1)<0.1; TargetC2Flag=abs(TestTargetOut-2)<0.1; TargetC3Flag=abs(TestTargetOut-3)<0.1; Target_C1_num=sum(TargetC1Flag); Target_C2_num=sum(TargetC2Flag); Target_C3_num=sum(TargetC3Flag); Test_C1_num=sum(NNC1Flag); Test_C2_num=sum(NNC2Flag); Test_C3_num=sum(NNC3Flag); Test_C1_C1=1.0*NNC1Flag*TargetC1Flag'; Test_C1_C2=1.0*NNC1Flag*TargetC2Flag'; Test_C1_C3=1.0*NNC1Flag*TargetC3Flag'; Test_C2_C1=1.0*NNC2Flag*TargetC1Flag'; Test_C2_C2=1.0*NNC2Flag*TargetC2Flag'; Test_C2_C3=1.0*NNC2Flag*TargetC3Flag'; Test_C3_C1=1.0*NNC3Flag*TargetC1Flag'; Test_C3_C2=1.0*NNC3Flag*TargetC2Flag'; Test_C3_C3=1.0*NNC3Flag*TargetC3Flag'; Test_Correct=(Test_C1_C1+Test_C2_C2+Test_C3_C3)/TestSamNum; % 输出格式设计 disp('///////////////////////////////////////////////////////////'); fprintf('\n'); disp(' 测试报告'); fprintf('\n'); fprintf('测试样本总数: %7.0f\n\n',TestSamNum); fprintf('第一类样本数: %7.0f\n',Target_C1_num); fprintf('第二类样本数: %7.0f\n',Target_C2_num); fprintf('第三类样本数: %7.0f\n\n',Target_C3_num); disp('= = = = = = = = = = = = = = = = = = = = = = = = = = = '); fprintf('\n'); fprintf('第一类样本分布(C1=%4.0f)\n',Test_C1_num); fprintf(' C11=%4.0f',Test_C1_C1); fprintf(' C12=%4.0f',Test_C1_C2); fprintf(' C13=%4.0f\n\n',Test_C1_C3); fprintf('第二类样本分布(C2=%3.0f)\n',Test_C2_num); fprintf(' C21=%4.0f',Test_C2_C1); fprintf(' C22=%4.0f',Test_C2_C2); fprintf(' C23=%4.0f\n\n',Test_C2_C3); fprintf('第三类样本分布(C3=%3.0f)\n',Test_C3_num); fprintf(' C31=%4.0f',Test_C3_C1); fprintf(' C32=%4.0f',Test_C3_C2); fprintf(' C33=%4.0f\n\n',Test_C3_C3); fprintf('正确率:%6.4f\n\n',Test_Correct); disp('///////////////////////////////////////////////////////////'); fprintf('\n\n');


网友评论

  • 感觉用处不是很大