简介
个人博客: https://xiaoxiablogs.top
最小二乘法就是用过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便的求得未知的数据。
一元线性回归下的最小二乘法
下面来讲解一下最小二乘法(以二维数据为例)
首先,我们得到一组数据(\(x_1,y_1\)), (\(x_2,y_2\))...(\(x_n,y_n\)),我们的预测函数 \(f(x_i)=\omega x_i+b\),也就是预测值\(\hat y_i\), 那么我们的误差的平方和为:
而我们需要使得上面的式子为最小值,从而求得我们需要的\(\omega和b\), 我们将其记作\((\omega^*, b^*)\),即
求解\(\omega\)和\(b\)的使得\(E_{(\omega, b)}=\sum^n_{i=1}(y_i-\hat y_i)^2\)的过程,称为线性回归模型的最小二乘"参数估计"
我们要求得\(E_{(\omega, b)}\)的最小值,只需要求得其极值即可。
我们可将\(E_{(\omega, b)}\)分别对\(\omega\)和\(b\)求偏导:
对上面的方程求解可以得到
其中\(\overline x=\frac1n\sum^n_{i=1}x_i\)即\(\overline x\)是\(x\)的均值
通过上面的步骤我们就可以得到最小二乘法的\(\omega和b\)了。
从而我们就可以得到关系式\(f(x_i)=\omega x_i+b\)
多元线性回归下的最小二乘法
同样的,如果将最小二乘法应用到\(n\)维数据中
我们的数据\(x\)如下:
对应的\(\omega\)为\(\left(\begin{matrix}\omega_1&\omega_2&\dots&\omega_n\end{matrix}\right)\),所对应的方程为
\(\omega_1 x_1+\omega_2 x_2+\dots+\omega_n x_n+b\)
为了方便计算,我们可以将\(b\)放在\(x\)和\(\omega\)中,即将\(b\)作为一维,其为固定值1,参数为\(\omega_b\)
因此,我们的方程就变为了\(f(x_i) = \omega^Tx\)
与上方一元线性回归下的误差类似地
令\(E_\omega=(y-X\omega)^T(y-X\omega)\),对\(\omega\)求导可得:
令上式等于零可得\(\omega\)的最优解:
从而可以得到我们需要的函数\(f(x_i)=\omega'^Tx'=\omega^Tx+\omega_b=\omega^Tx+b\)也就是\(f(x_i)=x_i^T(X^TX)^{-1}X^Ty\)
多元最小二乘法也是用与一元线性回归
最小二乘法的代码实现
def LeastSquareMethod(X, Y):
"""
最小二乘法
:param X: 未进行扩展的X矩阵
:param Y: X矩阵相对应的结果集矩阵
:return X_b: 进行扩展处理后的X矩阵
:return omega: 使用最小二乘法求得的w
"""
# 对X矩阵进行扩展
X_b = np.c_[np.ones((len(X), 1)), X]
'''
np.linalg.inv用来求矩阵的逆矩阵
dot表示矩阵祥恒
T表示矩阵的转置
'''
omega = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(Y)
return X_b, omega
实例
import numpy as np
import matplotlib.pyplot as plt
def LeastSquareMethod(X, Y):
"""
最小二乘法
:param X: 未进行扩展的X矩阵
:param Y: X矩阵相对应的结果集矩阵
:return X_b: 进行扩展处理后的X矩阵
:return omega: 使用最小二乘法求得的w
"""
# 对X矩阵进行扩展
X_b = np.c_[np.ones((len(X), 1)), X]
'''
np.linalg.inv用来求矩阵的逆矩阵
dot表示矩阵祥恒
T表示矩阵的转置
'''
omega = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(Y)
return X_b, omega
if __name__ == '__main__':
X = np.random.rand(100, 1)
Y = 4 + 3 * X + np.random.randn(100, 1)
X_b, omega = LeastSquareMethod(X, Y)
Y2 = X_b.dot(omega)
plt.plot(X, Y, 'o')
plt.plot(X, Y2, 'r')
plt.show()
得到的图像为