栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

线性回归--手动

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

线性回归--手动

解析解求解主要需要推导出 W 的计算公式:

y = w ∗ x + b = W ∗ X y = w * x + b = W*X y=w∗x+b=W∗X
为例,选取均方误差为损失函数:
l o s s = 1 2 n ∗ ( y − y p r e d ) 2 loss = frac{1}{2n} * (y - y_{pred})^2 loss=2n1​∗(y−ypred​)2
直接贴出推导结果(我推的太不好了):
W = ( X @ X T ) − 1 @ X @ Y W = (X@ X^T) ^{-1}@ X @ Y W=(X@XT)−1@X@Y

代码:

import numpy as np
import matplotlib.pyplot as plt
def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y

np.random.seed(10)
x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))
w = np.linalg.pinv(x @ np.transpose(x)) @ x @ y
print(w)
y_pred = w @ x
plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)
plt.show()

结果:
[3.1382164 0.78223531]

梯度下降求解以
y = w ∗ x + b = W ∗ X y = w * x + b = W*X y=w∗x+b=W∗X
为例,选取均方误差为损失函数:
l o s s = 1 2 n ∗ ( y − y p r e d ) 2 loss = frac{1}{2n} * (y - y_{pred})^2 loss=2n1​∗(y−ypred​)2
梯度计算:
∇ = 1 n ∗ ( y − W ∗ X ) ∗ X T nabla = frac{1}{n} * (y - W*X) *X^T ∇=n1​∗(y−W∗X)∗XT
利用梯度更新参数,注意梯度方向,系数更新公式:
W = W + a ∗ ∇ W = W + a * nabla W=W+a∗∇
a为学习率,不要太大,不然结果会乱跳(不收敛)

代码:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y
def monitor_mse(y, y_pred):
    Loss = ((y - y_pred) @ np.transpose(y - y_pred)) / len(y)
    return Loss
np.random.seed(10)


x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))

k = 2001
a = 0.01  # 学习率小点好,大了会乱跑
A = np.random.rand(2)

for i in range(1, k):
    y_pred = np.transpose(A) @ x
    A = A + a * ((y - y_pred) / len(y)) @ np.transpose(x)


    if i % 500 == 0:
        print(f"第 {i} 次 A:", A)
        print(f"第 {i} 次 A:", monitor_mse(y, y_pred))



plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)

plt.show()

结果:
第 500 次 A: [3.11735897 0.92539329]
第 500 次 A: 11.390507360402756
第 1000 次 A: [3.13215241 0.82385636]
第 1000 次 A: 11.385755910470582
第 1500 次 A: [3.13645339 0.79433601]
第 1500 次 A: 11.385354285166539
第 2000 次 A: [3.13770383 0.7857534 ]
第 2000 次 A: 11.385320337027098

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/769408.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号