logistic回归 求分界线方程

时间:2022-11-04 23:52:33

刚接触机器学习,读书之余编小程序测试效果。贴上供学习用的简易代码,仅供刚开始学习的新手做测试及理解算法用。菜鸟也要飞

没有写画分界线的代码,求得分界线方程后可用这个网站画出分界线图像: http://zh.numberempire.com/graphingcalculator.php

测试数据:坐标系中随便6个点的坐标。分成两类。标签分别设置成1,0。

#include <iostream>
#include <cmath>


using namespace std;

int main(){

double d1[3] = { 1, 1, 3 };//点一(1,3)
double d2[3] = { 1, 3, 1 };//点二(3,1)
double d3[3] = { 1, 4, 8 };//点三(4,8)
double d4[3] = { 1, 3, 5};//点四(3,5)
double d5[3] = { 1, 7, 7 };//点五(7,7)
double d6[3] = { 1, 5, 4 };//点六(5,4)
double w[3] = { 1, 1, 1 };//回归系数
double dclass[6] = { 0,0, 0, 1, 1, 1 };//指定的初始分类 点一、二、三类别为0,四、五、六类别为1
double h[6];//当前分类
double derror[6];//每步当前分类与指定分类的误差
double a = 0.01;//步长
double b0, b1, b2;//三个buffer,用来存储样本数据乘以该样本当前对应的误差 用来修正数据
for (int i = 0; i <1000; i++){//迭代1000次
h[0] = 1 / (1 + exp(-(d1[0] * w[0] + d1[1] * w[1] + d1[2] * w[2])));//sigmoid函数
h[1] = 1 / (1 + exp(-(d2[0] * w[0] + d2[1] * w[1] + d2[2] * w[2])));
h[2] = 1 / (1 + exp(-(d3[0] * w[0] + d3[1] * w[1] + d3[2] * w[2])));
h[3] = 1 / (1 + exp(-(d4[0] * w[0] + d4[1] * w[1] + d4[2] * w[2])));
h[4] = 1 / (1 + exp(-(d5[0] * w[0] + d5[1] * w[1] + d5[2] * w[2])));
h[5] = 1 / (1 + exp(-(d6[0] * w[0] + d6[1] * w[1] + d6[2] * w[2])));
for (int j = 0; j < 6; j++)
{
derror[j] = dclass[j] - h[j];
}
b0 = d1[0] * derror[0] + d2[0] * derror[1] + d3[0] * derror[2] + d4[0] * derror[3] + d5[0] * derror[4] + d6[0] * derror[5];
b1 = d1[1] * derror[0] + d2[1] * derror[1] + d3[1] * derror[2] + d4[1] * derror[3] + d5[1] * derror[4] + d6[1] * derror[5];
b2 = d1[2] * derror[0] + d2[2] * derror[1] + d3[2] * derror[2] + d4[2] * derror[3] + d5[2] * derror[4] + d6[2] * derror[5];

w[0] = w[0] + a*b0;
w[1] = w[1] + a*b1;
w[2] = w[2] + a*b2;
}

cout << "第一个权值=" << w[0] << " " << "第二个权值=" << w[1] << " 第三个权值=" << w[2] << endl << endl;//输出最佳回归系数
cout << "分界线公式:y=(-(" << w[0] << ")-(" << w[1] << ")*x)/(" << w[2] << ")" << endl;//输出分界线公式
for (int i = 0; i < 6; i++){
cout << "第" << i + 1 << "个个体的设定分类为 " << dclass[i] << " 当前分类为 " << h[i] << "当前误差为"<<derror[i]<<endl << endl;

}

system("pause");
return 1;
}