线性回归c++实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wz2671/article/details/77940621
/// A 2-D point / sample with public x and y coordinates.
/// The fields are public, so callers may read or write them directly.
class CPoint
{
public:
	double x;  // abscissa
	double y;  // ordinate

	/// Constructs the point (0, 0).
	CPoint()
		: x(0.0), y(0.0)
	{
	}

	/// Constructs the point (x, y).
	CPoint(double x, double y)
		: x(x), y(y)
	{
	}

	/// @return the x coordinate (const so it is usable on const points).
	double getX() const
	{
		return x;
	}

	/// @return the y coordinate (const so it is usable on const points).
	double getY() const
	{
		return y;
	}
};

 

 

//利用线性回归模型进行预测
//y = a+bx1+cx2...(为简化计算量,设方程为y = a + bx)
//实现方法:梯度下降法(逐样本更新参数:每轮先用全部样本依次更新 a,再依次更新 b,直到 a、b 的变化都不超过 1e-3)

#include "CPoint.h"
#include <iostream>
#include <vector>
#include <Cmath>
using namespace std;

/// Fits the line y = a + b*x to a set of 2-D samples by repeated passes of
/// per-sample gradient updates with a fixed learning rate `alpha`.
class LinearRegression
{
private:
	double a, b;               // fitted intercept (a) and slope (b)
	double lasta, lastb;       // a and b before the current pass, for the convergence test
	const double alpha = 0.5;  // learning rate applied to every per-sample update
public:
	/// Starts from the line y = 0 (a = b = 0).
	LinearRegression()
		: a(0.0), b(0.0), lasta(0.0), lastb(0.0)
	{
	}

	/// Fits a and b to the samples p[0..n-1].
	///
	/// Each pass first updates `a` once per sample, then updates `b` once per
	/// sample, using the current a and b for the prediction. Passes repeat
	/// until BOTH a and b change by at most 1e-3 during a pass, until a
	/// parameter becomes NaN, or until `maxIter` passes have run.
	/// The method name keeps its original spelling so existing callers compile.
	///
	/// @param p       pointer to n samples; may be null only when n <= 0
	/// @param n       number of samples; nothing is done when n <= 0
	/// @param maxIter upper bound on passes; guards against endless oscillation
	void GradentDescent(CPoint * p, int n, int maxIter = 100000)
	{
		if (p == nullptr || n <= 0)
			return;

		const double eps = 1e-3;
		for (int iter = 0; iter < maxIter; ++iter)
		{
			lasta = a;
			lastb = b;

			// Update the intercept a, one sample at a time.
			for (int i = 0; i < n; i++)
			{
				const double hx = a + b * p[i].getX();
				a += alpha * (p[i].getY() - hx);
			}
			// Then update the slope b, using the freshly updated a.
			for (int i = 0; i < n; i++)
			{
				const double hx = a + b * p[i].getX();
				b += alpha * (p[i].getY() - hx) * p[i].getX();
			}

			// Keep iterating while EITHER parameter is still moving. The previous
			// `&&` test stopped as soon as one of them settled. A NaN makes both
			// comparisons false, so a diverging run also stops here.
			const bool moving = std::fabs(lasta - a) > eps || std::fabs(lastb - b) > eps;
			if (!moving)
				break;
		}
	}

	/// Prints the fitted parameters to standard output.
	void show() const
	{
		std::cout << "a: " << a << std::endl;
		std::cout << "b: " << b << std::endl;
	}

	/// @return the fitted intercept a.
	double getA() const
	{
		return a;
	}

	/// @return the fitted slope b.
	double getB() const
	{
		return b;
	}
};

 

#include "LinearRegression.h"
#include <windows.h>
#include <math.h>
#include <stdio.h>

#define NUM    200 //测试数据共200个样本

LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);

/// Application entry point: registers the window class, creates and shows
/// the main window, then runs the message loop until WM_QUIT arrives.
/// @return the exit code carried by WM_QUIT, 0 on setup failure, or -1 if
///         GetMessage reports an error.
int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance,
	PSTR szCmdLine, int iCmdShow)
{
	static TCHAR szAppName[] = TEXT("win32");
	WNDCLASS     wndclass;

	wndclass.style = CS_HREDRAW | CS_VREDRAW;
	wndclass.lpfnWndProc = WndProc;
	wndclass.cbClsExtra = 0;
	wndclass.cbWndExtra = 0;
	wndclass.hInstance = hInstance;
	wndclass.hIcon = LoadIcon(NULL, IDI_APPLICATION);
	wndclass.hCursor = LoadCursor(NULL, IDC_ARROW);
	wndclass.hbrBackground = (HBRUSH)GetStockObject(WHITE_BRUSH);
	wndclass.lpszMenuName = NULL;
	wndclass.lpszClassName = szAppName;

	if (!RegisterClass(&wndclass))
	{
		MessageBox(NULL, TEXT("Program requires Windows NT!"),
			szAppName, MB_ICONERROR);
		return 0;
	}

	HWND hwnd = CreateWindow(szAppName, TEXT("win32"),
		WS_OVERLAPPEDWINDOW,
		CW_USEDEFAULT, CW_USEDEFAULT,
		CW_USEDEFAULT, CW_USEDEFAULT,
		NULL, NULL, hInstance, NULL);
	if (hwnd == NULL)
	{
		// Without a window, the message loop below would never see WM_QUIT.
		MessageBox(NULL, TEXT("CreateWindow failed!"),
			szAppName, MB_ICONERROR);
		return 0;
	}

	ShowWindow(hwnd, iCmdShow);
	UpdateWindow(hwnd);

	MSG  msg;
	BOOL bRet;
	// GetMessage returns 0 for WM_QUIT and -1 on error. A plain truthiness
	// test would keep looping on the error value forever.
	while ((bRet = GetMessage(&msg, NULL, 0, 0)) != 0)
	{
		if (bRet == -1)
			return -1;
		TranslateMessage(&msg);
		DispatchMessage(&msg);
	}
	// wParam of WM_QUIT holds the code passed to PostQuitMessage.
	return static_cast<int>(msg.wParam);
}

LRESULT CALLBACK WndProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam)
{
	static int  cxClient, cyClient;
	HDC         hdc;
	double      tmp;
	PAINTSTRUCT ps;
	CPoint      apt[NUM];
	FILE*		fp;
	char		str[1024];
	LinearRegression lr;

	switch (message)
	{
	case WM_SIZE:
		cxClient = LOWORD(lParam);
		cyClient = HIWORD(lParam);
		return 0;

	case WM_PAINT:
		hdc = BeginPaint(hwnd, &ps);

		MoveToEx(hdc, 0, cyClient / 2, NULL);
		LineTo(hdc, cxClient, cyClient / 2);

		MoveToEx(hdc, cxClient / 2, 0, NULL);
		LineTo(hdc, cxClient / 2, cyClient);

		//读取文件
		if (!(fp = fopen("data.txt", "r")))
		{
			printf("error");
			return -1;
		}
		
		for (int i = 0; i < NUM; i++)
		{
			fscanf(fp, "%lf", &tmp);
			fscanf(fp, "%lf", &apt[i].x);
			fscanf(fp, "%lf", &apt[i].y);
			apt[i].y = (apt[i].y-3)/2;
		}

		fclose(fp);

		SelectObject(hdc, GetStockObject(BLACK_BRUSH));
		//将x,y归一到0-1,根据窗口大小按比例显示于屏幕中
		for (int i = 0; i < NUM; i++)
		{
			Ellipse(hdc, cxClient*apt[i].x - 2, cyClient - cyClient*apt[i].y - 2, cxClient*apt[i].x + 2, cyClient - cyClient*apt[i].y + 2);
		}

		lr.GradentDescent(apt, NUM);
		//划线时同上按比例显示 原始之间方程应为 y=lr.getB()*x + lr.getA()
		MoveToEx(hdc, 0, cyClient - cyClient*lr.getA(), NULL);						//(0, lr.getA())
		LineTo(hdc, 1 * cxClient, cyClient - cyClient*(lr.getB() * 1 + lr.getA())); //(1, lr.getB()+lr.getA())

		return 0;

	case WM_DESTROY:
		PostQuitMessage(0);
		return 0;
	}
	return DefWindowProc(hwnd, message, wParam, lParam);
}


运行结果:

 

 

数据集及源码可从以下链接下载:https://download.csdn.net/download/wz2671/11129334 。数据文件 data.txt 共 200 个样本,每个样本为三个以空白分隔的数值:第一个数被程序忽略,其后依次为 x 和 y(程序读入后会把 y 变换为 (y-3)/2)。

展开阅读全文

没有更多推荐了,返回首页