栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

机器学习-------线性回归

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

机器学习-------线性回归

线性回归

1.给定数据集
D = ( x 1 , y 1 ) , ( x 2 , y 2 ) , ( x 3 , y 3 ) , ( x 4 , y 4 ) … … ( x n , y n ) D={(x_1,y_1),(x_2,y_2),(x_3,y_3),(x_4,y_4)}……(x_n,y_n) D=(x1​,y1​),(x2​,y2​),(x3​,y3​),(x4​,y4​)……(xn​,yn​)
希望可以得到
f ( x i ) = w x i + b f(x_i)=wx_i+b f(xi​)=wxi​+b
使得 f ( x i ) f(x_i) f(xi​)与 y i y_i yi​之间的差别尽可能小,这时我们引入损失函数
l o s s ( w , b ) = ∑ i = 1 n ( f ( x i ) − y i 2 ) loss(w,b)=sum_{i=1}^{n}(f(x_i)-y_i^2) loss(w,b)=i=1∑n​(f(xi​)−yi2​)
此损失函数是基于均方差来构造的,可通过最小二乘法来进行求解。

1.一元线性回归

可直接对w和b求偏导,让其导数为零,即可求解出w,b

2.多元线性回归

x为特征矩阵,y为标签矩阵。假设函数为
f ( x i ) = w 1 ∗ x 1 + w 2 ∗ x 2 + w 3 ∗ x 3 + … … + w i ∗ x i + w 0 f(x_i)=w_1*x_1+w_2*x_2+w_3*x_3+……+w_i*x_i+w_0 f(xi​)=w1​∗x1​+w2​∗x2​+w3​∗x3​+……+wi​∗xi​+w0​
为了计算方便,我们给x添加一个特征1,

则 f = W X f=WX f=WX

损失函数为 l o s s = ∑ i = 1 m ( x i ∗ W − y i ) 2 loss=sum_{i=1}^{m}(x_i*W-y_i)^2 loss=∑i=1m​(xi​∗W−yi​)2

该函数可写成矩阵相乘形式,对矩阵进行求导,可解得
w = ( X T X ) − 1 X Y w=(X^TX)^{-1}XY w=(XTX)−1XY
详细证明正在学习中,之后在更新

例子:希望通过分析pizza半径与价格的关系,来预测任意半径pizza的价格

#导入必要的模块
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from  sklearn import metrics
#导入数据
pizza=pd.read_csv("pizza.csv",index_col='ld')
pizza.head()
dia=pizza.loc[:,'Diameter'].values
price=pizza.loc[:,'Price'].values

print(dia)
print(price)
dia_new=dia.T
print(dia_new.shape)
print(price.shape)
[ 6  8 10 14 18]
[ 7.   9.  13.  17.5 18. ]
(5,)
(5,)
dia_new=np.mat(dia).reshape(-1,1)
print(dia_new.shape)
#打印散点图
plt.scatter(dia,price)
plt.show()

(5, 1)

#将dia矩阵添加一个 为 1 的特征
one_colum=np.ones((dia_new.shape[0],1))
print(one_colum.shape)
x_new=np.concatenate((one_colum,dia_new),axis=1)
print(x_new)
(5, 1)
[[ 1.  6.]
 [ 1.  8.]
 [ 1. 10.]
 [ 1. 14.]
 [ 1. 18.]]

w = ( X T X ) − 1 X Y w=(X^TX)^{-1}XY w=(XTX)−1XY

#手动计算最小二乘法
theta=np.dot(np.dot(np.linalg.inv(np.dot(x_new.T,x_new)),x_new.T),price)
print(theta)
[[1.96551724 0.9762931 ]]
#导入模型
model=LinearRegression()
#训练模型
model.fit(dia_new,price)
#打印w
print(model.coef_)
#打印b
print(model.intercept_)
[0.9762931]
1.965517241379315

对一个模型好坏的判断有MSE(均方误差),RMSE(均方根误差),决定系数 R 2 R^2 R2等。
M S E = 1 n s u m i = 1 m w i ( y i − y i ^ ) 2 MSE=frac{1}{n}_sum_{i=1}^{m}w_i(y_i-hat{y_i})^2 MSE=n1​s​umi=1m​wi​(yi​−yi​^​)2

R M S E = M S E RMSE=sqrt{MSE} RMSE=MSE ​

R 2 = 1 − M S E v a r ( y ⃗ ) R^2=1-frac{MSE}{var(vec{y})} R2=1−var(y ​)MSE​

R 2 R^2 R2越接大,拟合效果越好,越小则拟合效果越差

#预测价格
predict_price=model.predict(dia_new)
#mse
mse=metrics.mean_squared_error(price,predict_price)
print('MSE: ',mse)
#rmse
rmse=np.sqrt(mse)
print('RMSE:',rmse)
#r2
r2=metrics.r2_score(price,predict_price)
print('r2: ',r2)
MSE:  1.7495689655172406
RMSE: 1.3227127297781784
r2:  0.9100015964240102
#w
w=model.coef_
#截距
b=model.intercept_
x0=min(dia)
x1=max(dia)
y0=w*x0+b
y1=w*x1+b
plt.scatter(dia,price)
plt.plot([x0,x1],[y0,y1],c='r',alpha=0.5)
plt.show()


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580650.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号