文件名称: BP网络实现分类问题
文件大小:223KB
文件格式:DOCX
更新时间:2015-12-14 07:58:55
BP解决分类
function main()
InDim=2; % 样本输入维数
OutDim=3; % 样本输出维数
% figure
% colordef(gcf,'white')
% echo off
% clc
% axis([-2,2,-2,2])
% axis on
% grid
% xlabel('Input x');
% ylabel('Input y');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% line([-1 1],[1 1])
% line([1 -1],[1 0])
% line([-1 -1],[0 1])
% line([-1 1],[-0.5 -0.5])
% line([-1 1],[-1.5 -1.5])
% line([1 1],[-0.5 -1.5])
% line([-1 -1],[-0.5 -1.5])
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% hold on
% sj=plot([-1 1],[1 1],[1 -1],[1 0],[-1 -1],[0 1]);
% hold on
% set(sj,'Color','r','LineWidth',4);
% js=plot([-1 1],[-0.5 -0.5],'b',[-1 1],[-1.5 -1.5],'b',[1 1],...
% [-0.5 -1.5],'b',[-1 -1],[-0.5 -1.5],'b');
% hold on
% set(js,'Color','b','LineWidth',4);
%hold off
figure
colordef(gcf,'white')
echo off
clc
axis([-2,2,-2,2])
axis on
grid
xlabel('Input x');
ylabel('Input y');
hold on
sj=plot([-1 1],[1 1],[1 -1],[1 0],[-1 -1],[0 1]);
hold on
js=plot([-1 1],[-0.5 -0.5],'b',[-1 1],[-1.5 -1.5],'b',[1 1],...
[-0.5 -1.5],'b',[-1 -1],[-0.5 -1.5],'b');
hold on
set(sj,'Color','r','LineWidth',4);
set(js,'Color','b','LineWidth',4);
hold on
SamNum=400; % 训练样本数
rand('state', sum(100*clock))
SamIn=(rand(2,SamNum)-0.5)*4; % 产生随机样本输入
% 根据目标函数获得训练样本输入输出,并绘制样本
SamOut=[];
for i=1:SamNum
Sam=SamIn(:,i);
x=Sam(1,1);
y=Sam(2,1);
if((x>-1)&(x<1))==1
if((y>x/2+1/2)&(y<1))==1
plot(x,y,'k+')
class=[0 1 0]';
elseif((y<-0.5)&(y>-1.5))==1
plot(x,y,'ks')
class=[0 0 1]';
else
plot(x,y,'ko')
class=[1 0 0]';
end
else
plot(x,y,'ko')
class=[1 0 0]';
end
SamOut=[SamOut class];
end
HiddenUnitNum=10; % 隐节点数
MaxEpochs=10000; % 最大训练次数
lr=0.1; % 学习率
E0=0.1; % 目标误差
W1=0.2*rand(HiddenUnitNum,InDim)-0.1; % 输入层到隐层的初始权值
B1=0.2*rand(HiddenUnitNum,1)-0.1; % 隐节点初始偏移
W2=0.2*rand(OutDim,HiddenUnitNum)-0.1; % 隐层到输出层的初始权值
B2=0.2*rand(OutDim,1)-0.1; % 输出层初始偏移
W1Ex=[W1 B1]; % 输入层到隐层的初始权值扩展, 10*3
W2Ex=[W2 B2]; % 隐层到输出层的初始权值, 3*11
SamInEx=[SamIn' ones(SamNum,1)]'; % 样本输入扩展, 3*200
ErrHistory=[]; % 用于记录每次权值调整后的训练误差
for i=1:MaxEpochs
% 正向传播计算网络输出
HiddenOut=logsig(W1Ex*SamInEx);
HiddenOutEx=[HiddenOut' ones(SamNum, 1)]';
NetworkOut=logsig(W2Ex*HiddenOutEx);
% 停止学习判断
Error=SamOut-NetworkOut;
SSE=sumsqr(Error);
fprintf('Times: %7.0f',i);
fprintf(' SSE: .4f\n\n',SSE);
% 记录每次权值调整后的训练误差
ErrHistory=[ErrHistory SSE];
if SSE