栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【机器学习】02 两个参数的梯度下降 (详细注释+动态训练效果图)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【机器学习】02 两个参数的梯度下降 (详细注释+动态训练效果图)

前言

在上一篇文章中【机器学习】01 梯度下降(详细注释+动态训练效果图)我是假设蘑菇的大小和毒性成正比例关系。

不过这样假设太过于简单,这一次我们就增加一点难度,假设两者之间是成一次函数关系。
也就是 y = w x + b y=wx+b y=wx+b。那么现在我们就要通过梯度下降得到w和b两个参数的值。
梯度下降的方法和前面说的一样,就只要把对w和对b的偏导求出来就ok了!

下面就是梯度下降的核心代码

dw = 2*w*x**2+2*b*x-2*x*y
db=2*b+2*w*x-2*y
 # 现在不能叫斜率了,应该叫偏导
alpha = 0.1
# 学习率
w = w - alpha * dw
b=b-alpha*db
下面开始正式的梯度下降 生成数据集

按照 y = 0.5 x + 0.4 y=0.5x+0.4 y=0.5x+0.4生成数据

import numpy as np
def get_mushroom(counts):
    xs = np.random.rand(counts)
    # 生成counts个x
    xs = np.sort(xs)
    # 将x们按顺序排列
    ys = np.array([0.5 * x + np.random.rand() / 10 + 0.4 for x in xs])
    # 按照y=0.5x+0.4生成y,并加入一些小偏差
    return xs, ys

设初始的预测函数为 y = 1.2 x − 0.5 y=1.2x-0.5 y=1.2x−0.5

# 随便给出一个预测函数
w = 1.2
b=-0.5
y_pre = w * xs +b

画一个蘑菇的大小和毒性的图,散点图是原始数据,绿色的直线是一开始的预测函数。

梯度下降

下面就是梯度下降以及画动态效果图的代码了

for _ in range(50):
    for i in range(beans_num):
        x = xs[i]
        y = ys[i]
        # 斜率为代价函数对参数求导得来的
        # k = 2 * (x ** 2) * w -2 * x * y
        dw = 2*w*x**2+2*b*x-2*x*y
        db=2*b+2*w*x-2*y
        # 现在不能叫斜率了,应该叫偏导
        alpha = 0.1
        # 学习率
        w = w - alpha * dw
        b=b-alpha*db

    # 下面是利用 matplotlib 来画一个动态图,画出参数不断调整的过程
    plt.clf()
    # 清空窗口
    plt.scatter(xs, ys,color='moccasin')
    y_pre = w * xs+b
    # print(y_pre)
    plt.plot(xs, y_pre, color='crimson')
    es = (ys - y_pre) ** 2
    avg_e = np.sum(es) / beans_num
    plt.xlim(0,1)
    plt.ylim(0,1.2)
    plt.text(0, 0.87, "误差:",color="mediumturquoise",font='STSong',fontsize=14)
    plt.text(0.1, 0.87, avg_e,color="mediumblue",font='STSong',fontsize=15)
    plt.text(0.03,1.13,"w:",color="teal",font='STSong',fontsize=14)
    plt.text(0.1,1.13,w,color="teal",font='STSong',fontsize=14)
    plt.text(0.03, 1, "b:", color="teal", font='STSong', fontsize=14)
    plt.text(0.1, 1, b, color="teal", font='STSong', fontsize=14)
    plt.pause(0.01)

plt.show()

下面是动态训练效果

完整代码
import numpy as np
from matplotlib import pyplot as plt
import dataset

beans_num = 100
xs, ys = dataset.get_mushroom(beans_num)
# print(xs)
# print(ys)
# 画一个散点图,看看豆豆的大小和毒性大致是一个什么关系
plt.scatter(xs, ys)
# 画个标题和横纵坐标
plt.title("Size-Toxicity Function", fontsize=16)
plt.xlabel("Mushroom Size", fontsize=14)
plt.ylabel("Toxicity", fontsize=14)
# plt.show()

# 随便给出一个预测函数
w = 1.2
b=-0.5
y_pre = w * xs +b
# print(y_pre)
plt.plot(xs, y_pre, color='green')
plt.show()

# 上面还是错的离谱,下面使用梯度下降算法来修正w

# 普通随机下降
for _ in range(50):
    for i in range(beans_num):
        x = xs[i]
        y = ys[i]
        # 斜率为代价函数对参数求导得来的
        # k = 2 * (x ** 2) * w -2 * x * y
        dw = 2*w*x**2+2*b*x-2*x*y
        db=2*b+2*w*x-2*y
        # 现在不能叫斜率了,应该叫偏导
        alpha = 0.1
        # 学习率
        w = w - alpha * dw
        b=b-alpha*db

    # 下面是利用 matplotlib 来画一个动态图,画出参数不断调整的过程
    plt.clf()
    # 清空窗口
    plt.scatter(xs, ys,color='moccasin')
    y_pre = w * xs+b
    # print(y_pre)
    plt.plot(xs, y_pre, color='crimson')
    es = (ys - y_pre) ** 2
    avg_e = np.sum(es) / beans_num
    plt.xlim(0,1)
    plt.ylim(0,1.2)
    plt.text(0, 0.87, "误差:",color="mediumturquoise",font='STSong',fontsize=14)
    plt.text(0.1, 0.87, avg_e,color="mediumblue",font='STSong',fontsize=15)
    plt.text(0.03,1.13,"w:",color="teal",font='STSong',fontsize=14)
    plt.text(0.1,1.13,w,color="teal",font='STSong',fontsize=14)
    plt.text(0.03, 1, "b:", color="teal", font='STSong', fontsize=14)
    plt.text(0.1, 1, b, color="teal", font='STSong', fontsize=14)
    plt.pause(0.01)

plt.show()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/744855.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号