// A 2-D sample point: x is the input value, y the observed value.
// Members are public (callers read and write them directly); the getters are
// const so they also work on const points and through const pointers.
class CPoint
{
public:
    double x;
    double y;

    // Creates the point (0, 0).
    CPoint() : x(0.0), y(0.0)
    {
    }

    // Creates the point (xCoord, yCoord).
    CPoint(double xCoord, double yCoord) : x(xCoord), y(yCoord)
    {
    }

    // Returns the x coordinate.
    double getX() const
    {
        return x;
    }

    // Returns the y coordinate.
    double getY() const
    {
        return y;
    }
};
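// Prediction with a linear regression model.
// General form: y = a + b*x1 + c*x2 + ...; to keep the computation small,
// only the single-variable model y = a + b*x is fitted here.
// Method: gradient descent on the squared error, applied per sample
// (incremental updates: one pass over the data for a, then one pass for b).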
#include "CPoint.h"
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
// Fits the single-variable model y = a + b*x using per-sample gradient steps
// on the squared error.
class LinearRegression
{
private:
    double a, b;          // intercept and slope of the fitted line y = a + b*x
    double lasta, lastb;  // a and b at the start of the current epoch (for the convergence test)
    double alpha;         // learning rate (step size); not const, so objects stay assignable
public:
    // Starts from a = b = 0 with a learning rate of 0.5.
    LinearRegression() : LinearRegression(0.5)
    {
    }

    // Starts from a = b = 0 with a caller-chosen learning rate.
    explicit LinearRegression(double learningRate)
        : a(0.0), b(0.0), lasta(0.0), lastb(0.0), alpha(learningRate)
    {
    }

    // Runs epochs over the n samples in p. Each epoch first sweeps all samples
    // updating a, then sweeps them again updating b. Iteration stops once
    // neither a nor b changed by more than `tolerance` during an epoch, or
    // after `maxEpochs` epochs. a and b continue from their current values,
    // so a second call keeps refining the previous fit.
    //   p         - array of at least n samples (getX() = input, getY() = target)
    //   n         - number of samples; the call does nothing if n <= 0 or p is null
    //   tolerance - per-epoch change threshold applied to both a and b
    //   maxEpochs - safety cap in case the updates never settle
    void GradentDescent(CPoint * p, int n, double tolerance = 1e-3, int maxEpochs = 100000)
    {
        if (p == nullptr || n <= 0)
            return;

        for (int epoch = 0; epoch < maxEpochs; ++epoch)
        {
            lasta = a;
            lastb = b;

            // Update the intercept a, one sample at a time.
            for (int i = 0; i < n; i++)
            {
                const double hx = a + b * p[i].getX();
                a += alpha * (p[i].getY() - hx);
            }

            // Then update the slope b, one sample at a time.
            for (int i = 0; i < n; i++)
            {
                const double hx = a + b * p[i].getX();
                b += alpha * (p[i].getY() - hx) * p[i].getX();
            }

            // Converged only when BOTH parameters have settled; stopping as soon
            // as one of them settles would leave the other unconverged.
            // Phrased as "still moving" tests so that a NaN difference (a
            // diverged run) also ends the loop instead of spinning to maxEpochs.
            const bool aMoving = std::fabs(lasta - a) > tolerance;
            const bool bMoving = std::fabs(lastb - b) > tolerance;
            if (!aMoving && !bMoving)
                break;
        }
    }

    // Prints the fitted intercept and slope to standard output.
    void show() const
    {
        std::cout << "a: " << a << '\n';
        std::cout << "b: " << b << '\n';
    }

    // Returns the fitted intercept a.
    double getA() const
    {
        return a;
    }

    // Returns the fitted slope b.
    double getB() const
    {
        return b;
    }
};
#include "LinearRegression.h"
#include <windows.h>
#include <math.h>
#include <stdio.h>

#define NUM 200 // the test data set contains 200 samples

// Location of the sample file; can be overridden at build time (e.g. /DDATA_FILE="...").
#ifndef DATA_FILE
#define DATA_FILE "C:\\Users\\Wenzhou\\Desktop\\study\\mlDATA\\data.txt"
#endif

LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);

// Reads up to `capacity` samples from `path` into `pts`.
// Each record is three whitespace-separated numbers: a leading value that is
// ignored, then x, then y. y is rescaled with (y - 3) / 2.
// NOTE(review): the rescale presumably maps this data set's y into [0, 1] for
// plotting — confirm against the data file.
// Returns the number of complete records read, or -1 if the file cannot be opened.
static int LoadSamples(const char* path, CPoint* pts, int capacity)
{
    FILE* fp = fopen(path, "r");
    if (fp == NULL)
        return -1;

    int count = 0;
    double ignored;
    // Stop at the first incomplete record instead of using uninitialised/default points.
    while (count < capacity &&
           fscanf(fp, "%lf %lf %lf", &ignored, &pts[count].x, &pts[count].y) == 3)
    {
        pts[count].y = (pts[count].y - 3) / 2;
        ++count;
    }
    fclose(fp);
    return count;
}

// Maps a data-space x (expected in [0, 1]) to a client-area column.
static int ToScreenX(double x, int cxClient)
{
    return static_cast<int>(cxClient * x);
}

// Maps a data-space y (expected in [0, 1]) to a client-area row; data y grows upward.
static int ToScreenY(double y, int cyClient)
{
    return static_cast<int>(cyClient - cyClient * y);
}

// Draws the axes, the `count` samples and, when samples exist, the fitted line y = a + b*x.
static void DrawScene(HDC hdc, int cxClient, int cyClient,
                      const CPoint* pts, int count, double a, double b)
{
    // Axes through the centre of the client area.
    MoveToEx(hdc, 0, cyClient / 2, NULL);
    LineTo(hdc, cxClient, cyClient / 2);
    MoveToEx(hdc, cxClient / 2, 0, NULL);
    LineTo(hdc, cxClient / 2, cyClient);

    if (count <= 0)
        return;

    // Samples as small filled circles, scaled to the window size.
    HGDIOBJ oldBrush = SelectObject(hdc, GetStockObject(BLACK_BRUSH));
    for (int i = 0; i < count; i++)
    {
        const int px = ToScreenX(pts[i].x, cxClient);
        const int py = ToScreenY(pts[i].y, cyClient);
        Ellipse(hdc, px - 2, py - 2, px + 2, py + 2);
    }
    SelectObject(hdc, oldBrush);

    // Fitted line from (0, a) to (1, b*1 + a), scaled the same way as the samples.
    MoveToEx(hdc, 0, ToScreenY(a, cyClient), NULL);
    LineTo(hdc, ToScreenX(1.0, cxClient), ToScreenY(b * 1 + a, cyClient));
}

int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, PSTR szCmdLine, int iCmdShow)
{
    static TCHAR szAppName[] = TEXT("win32");
    HWND hwnd;
    MSG msg;
    WNDCLASS wndclass;

    wndclass.style = CS_HREDRAW | CS_VREDRAW;
    wndclass.lpfnWndProc = WndProc;
    wndclass.cbClsExtra = 0;
    wndclass.cbWndExtra = 0;
    wndclass.hInstance = hInstance;
    wndclass.hIcon = LoadIcon(NULL, IDI_APPLICATION);
    wndclass.hCursor = LoadCursor(NULL, IDC_ARROW);
    wndclass.hbrBackground = (HBRUSH)GetStockObject(WHITE_BRUSH);
    wndclass.lpszMenuName = NULL;
    wndclass.lpszClassName = szAppName;

    if (!RegisterClass(&wndclass))
    {
        MessageBox(NULL, TEXT("Program requires Windows NT!"), szAppName, MB_ICONERROR);
        return 0;
    }

    hwnd = CreateWindow(szAppName, TEXT("win32"), WS_OVERLAPPEDWINDOW,
                        CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT,
                        NULL, NULL, hInstance, NULL);
    if (hwnd == NULL)
        return 0;

    ShowWindow(hwnd, iCmdShow);
    UpdateWindow(hwnd);

    // GetMessage returns -1 on error; treating that as "true" would loop forever.
    BOOL ret;
    while ((ret = GetMessage(&msg, NULL, 0, 0)) != 0)
    {
        if (ret == -1)
            return -1;
        TranslateMessage(&msg);
        DispatchMessage(&msg);
    }
    return static_cast<int>(msg.wParam);
}

LRESULT CALLBACK WndProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam)
{
    static int cxClient, cyClient;
    static CPoint apt[NUM];      // samples, loaded once in WM_CREATE
    static int count = 0;        // number of samples actually read from the file
    static LinearRegression lr;  // fitted once in WM_CREATE
    HDC hdc;
    PAINTSTRUCT ps;

    switch (message)
    {
    case WM_CREATE:
        // Load and fit once, instead of re-reading the file on every repaint.
        count = LoadSamples(DATA_FILE, apt, NUM);
        if (count < 0)
        {
            count = 0;
            // A GUI process has no console, so report the failure visibly.
            MessageBox(hwnd, TEXT("error"), TEXT("win32"), MB_ICONERROR);
        }
        else if (count > 0)
        {
            lr.GradentDescent(apt, count);
        }
        return 0;

    case WM_SIZE:
        cxClient = LOWORD(lParam);
        cyClient = HIWORD(lParam);
        return 0;

    case WM_PAINT:
        hdc = BeginPaint(hwnd, &ps);
        DrawScene(hdc, cxClient, cyClient, apt, count, lr.getA(), lr.getB());
        EndPaint(hwnd, &ps);  // every BeginPaint must be paired with EndPaint
        return 0;

    case WM_DESTROY:
        PostQuitMessage(0);
        return 0;
    }
    return DefWindowProc(hwnd, message, wParam, lParam);
}
运行结果: