线性回归 C++ 实现

时间:2022-08-26 10:20:40
// A 2-D point with public x/y coordinates. LinearRegression treats x as the
// input and y as the target of one training sample.
class CPoint
{
public:
	double x = 0.0;  // horizontal coordinate
	double y = 0.0;  // vertical coordinate

	// Creates the origin (0, 0).
	CPoint() = default;

	// Creates the point (x, y).
	CPoint(double x, double y) : x(x), y(y) {}

	// Returns the x coordinate; const so it also works on const points.
	double getX() const
	{
		return x;
	}

	// Returns the y coordinate; const so it also works on const points.
	double getY() const
	{
		return y;
	}
};


//利用线性回归模型进行预测
//y = a+bx1+cx2...(为简化计算量,设方程为y = a + bx)
//实现方法:梯度下降法

#include "CPoint.h"
#include <iostream>
#include <vector>
#include <Cmath>
using namespace std;

class LinearRegression
{
private:
double a, b;
double lasta, lastb;
const double alpha = 0.5;
public:
LinearRegression()
{
a = 0.0;
b = 0.0;
}
void GradentDescent(CPoint * p, int n)
{

do
{
lasta = a;
lastb = b;
//首先更新a
for (int i = 0; i < n; i++)
{
double hx = a + b*p[i].getX();
a = a + alpha*(p[i].getY() - hx);
}
//然后更新b
for (int i = 0; i < n; i++)
{
double hx = a + b*p[i].getX();
b = b + alpha*(p[i].getY() - hx)*p[i].getX();
}
} while (fabs(lasta - a) > 1e-3 && fabs(lastb - b) > 1e-3);//收敛条件

}
void show()
{
cout << "a: " << a << endl;
cout << "b: " << b << endl;
}
double getA()
{
return a;
}

double getB()
{
return b;
}
};

#include "LinearRegression.h"
#include <windows.h>
#include <math.h>
#include <stdio.h>

// Maximum number of samples read from the data file (the test data has 200).
const int NUM = 200;

// Sample file: one sample per row, three whitespace-separated numbers per row
// (a leading value that is read and discarded, then x, then y).
static const char kDataPath[] = "C:\\Users\\Wenzhou\\Desktop\\study\\mlDATA\\data.txt";

LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);

// Registers the window class, creates the main window and runs the message loop.
int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance,
	PSTR szCmdLine, int iCmdShow)
{
	static TCHAR szAppName[] = TEXT("win32");

	WNDCLASS wndclass = {};
	wndclass.style = CS_HREDRAW | CS_VREDRAW;
	wndclass.lpfnWndProc = WndProc;
	wndclass.cbClsExtra = 0;
	wndclass.cbWndExtra = 0;
	wndclass.hInstance = hInstance;
	wndclass.hIcon = LoadIcon(nullptr, IDI_APPLICATION);
	wndclass.hCursor = LoadCursor(nullptr, IDC_ARROW);
	wndclass.hbrBackground = static_cast<HBRUSH>(GetStockObject(WHITE_BRUSH));
	wndclass.lpszMenuName = nullptr;
	wndclass.lpszClassName = szAppName;

	if (!RegisterClass(&wndclass))
	{
		MessageBox(nullptr, TEXT("Program requires Windows NT!"),
			szAppName, MB_ICONERROR);
		return 0;
	}

	HWND hwnd = CreateWindow(szAppName, TEXT("win32"),
		WS_OVERLAPPEDWINDOW,
		CW_USEDEFAULT, CW_USEDEFAULT,
		CW_USEDEFAULT, CW_USEDEFAULT,
		nullptr, nullptr, hInstance, nullptr);
	if (!hwnd)
		return 0;  // without a window the message loop would never end

	ShowWindow(hwnd, iCmdShow);
	UpdateWindow(hwnd);

	// GetMessage returns -1 on error; testing only for non-zero would spin forever.
	MSG msg = {};
	while (GetMessage(&msg, nullptr, 0, 0) > 0)
	{
		TranslateMessage(&msg);
		DispatchMessage(&msg);
	}
	return static_cast<int>(msg.wParam);
}

// Reads up to maxPoints samples from path into pts, remapping each y with
// (y - 3) / 2 (the drawing code expects x and y in 0..1).
// Returns the number of complete samples read (fewer than maxPoints if the
// file is short or malformed), or -1 if the file cannot be opened.
static int LoadPoints(const char* path, CPoint* pts, int maxPoints)
{
	FILE* fp = fopen(path, "r");
	if (!fp)
		return -1;

	int count = 0;
	double ignored;  // leading column of each row, not used
	while (count < maxPoints &&
		fscanf(fp, "%lf %lf %lf", &ignored, &pts[count].x, &pts[count].y) == 3)
	{
		pts[count].y = (pts[count].y - 3) / 2;
		++count;
	}
	fclose(fp);
	return count;
}

// Draws the axes, the samples and the fitted line y = a + b*x into hdc.
// Coordinates are scaled from 0..1 to the client area, with y flipped so that
// larger y values are drawn higher on screen.
static void PaintRegression(HDC hdc, int cxClient, int cyClient)
{
	// Axes through the centre of the client area.
	MoveToEx(hdc, 0, cyClient / 2, nullptr);
	LineTo(hdc, cxClient, cyClient / 2);
	MoveToEx(hdc, cxClient / 2, 0, nullptr);
	LineTo(hdc, cxClient / 2, cyClient);

	CPoint apt[NUM];
	const int count = LoadPoints(kDataPath, apt, NUM);
	if (count < 0)
	{
		// NOTE(review): a WinMain (GUI) program usually has no console attached,
		// so this message is likely invisible; consider reporting it in the window.
		printf("error");
		return;
	}

	// Each sample as a small filled dot; restore the previous brush afterwards.
	HGDIOBJ oldBrush = SelectObject(hdc, GetStockObject(BLACK_BRUSH));
	for (int i = 0; i < count; i++)
	{
		const double px = cxClient * apt[i].x;
		const double py = cyClient - cyClient * apt[i].y;
		Ellipse(hdc, static_cast<int>(px - 2), static_cast<int>(py - 2),
			static_cast<int>(px + 2), static_cast<int>(py + 2));
	}
	SelectObject(hdc, oldBrush);

	if (count == 0)
		return;  // nothing to fit

	LinearRegression lr;
	lr.GradentDescent(apt, count);

	// Fitted line from x = 0 to x = 1, scaled the same way as the samples:
	// (0, a) to (1, a + b).
	const double a = lr.getA();
	const double b = lr.getB();
	MoveToEx(hdc, 0, static_cast<int>(cyClient - cyClient * a), nullptr);
	LineTo(hdc, cxClient, static_cast<int>(cyClient - cyClient * (b + a)));
}

// Window procedure: tracks the client size, repaints the plot on WM_PAINT
// and quits on WM_DESTROY.
LRESULT CALLBACK WndProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam)
{
	static int cxClient, cyClient;

	switch (message)
	{
	case WM_SIZE:
		cxClient = LOWORD(lParam);
		cyClient = HIWORD(lParam);
		return 0;

	case WM_PAINT:
	{
		PAINTSTRUCT ps;
		HDC hdc = BeginPaint(hwnd, &ps);
		PaintRegression(hdc, cxClient, cyClient);
		// Every BeginPaint must be matched by EndPaint, on every path.
		EndPaint(hwnd, &ps);
		return 0;
	}

	case WM_DESTROY:
		PostQuitMessage(0);
		return 0;
	}
	return DefWindowProc(hwnd, message, wParam, lParam);
}

运行结果:

（运行结果截图缺失；原图说明：线性回归 C++ 实现）