栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

线性可分:感知机

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

线性可分:感知机

感知机思想:错误驱动
模型:
f ( x ) = s i g n ( W T x ) f(x)=sign(W^Tx) f(x)=sign(WTx), x ∈ R p xin R^p x∈Rp, W ∈ R p Win R^p W∈Rp
s i g n ( a ) = { + 1 , a ≥ 0 − 1 , a < 0 sign(a)=left{begin{matrix} +1,ageq0 \ -1,a<0 end{matrix}right. sign(a)={+1,a≥0−1,a<0​

前提:数据是线性可分的
样本集: { x i , y i } i = 1 N {x_i,y_i}_{i=1}^{N} {xi​,yi​}i=1N​
先给 W W W一个初始值 W 0 W_0 W0​
D : 被 错 误 分 类 的 样 本 D:{被错误分类的样本} D:被错误分类的样本
策略:
loss function:被错误分类的点的个数
L ( W ) = ∑ i = 1 N I { y i W T x i < 0 } L(W)=sumlimits_{i=1}^{N}I{y_iW^Tx_i<0} L(W)=i=1∑N​I{yi​WTxi​<0}
当样本点被正确分类时: y i W T x i > 0 y_iW^Tx_i>0 yi​WTxi​>0
W T x i > 0 W^Tx_i>0 WTxi​>0时, y i = + 1 y_i=+1 yi​=+1
W T x i < 0 W^Tx_i<0 WTxi​<0时, y i = − 1 y_i=-1 yi​=−1
那么样本点被错误分类时, y i W T x i < 0 y_iW^Tx_i<0 yi​WTxi​<0
但是此时 L ( W ) L(W) L(W)不可导,所以这个损失函数不合适。所以改用以下的损失函数

L ( W ) = ∑ x i ∈ D − y i W T x i L(W)=sumlimits_{x_iin D} -y_iW^Tx_i L(W)=xi​∈D∑​−yi​WTxi​
在代码的时候可以用随机梯度下降优化

实验数据:

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib 
import matplotlib.pyplot as plt

data = pd.read_csv('data/9-2-data.csv')
dataMat = np.mat(data.iloc[:,:2].values)
labelMat = np.mat(data.iloc[:,-1].values).T
m, n = np.shape(dataMat)
w = np.zeros((1, np.shape(dataMat)[1]))
#初始化偏置b为0
b = 0
#初始化步长,也就是梯度下降过程中的n,控制梯度下降速率
h = 0.0001
for k in range(50):
        #对于每一个样本进行梯度下降
        #李航书中在2.3.1开头部分使用的梯度下降,是全部样本都算一遍以后,统一
        #进行一次梯度下降
        #在2.3.1的后半部分可以看到(例如公式2.6 2.7),求和符号没有了,此时用
        #的是随机梯度下降,即计算一个样本就针对该样本进行一次梯度下降。
        #两者的差异各有千秋,但较为常用的是随机梯度下降。
        for i in range(m):
            #获取当前样本的向量
            xi = dataMat[i]
            #获取当前样本所对应的标签
            yi = labelMat[i]
            #判断是否是误分类样本
            #误分类样本特诊为: -yi(w*xi+b)>=0,详细可参考书中2.2.2小节
            #在书的公式中写的是>0,实际上如果=0,说明改点在超平面上,也是不正确的
            if -1 * yi * (w * xi.T + b) >= 0:
                #对于误分类样本,进行梯度下降,更新w和b
                w = w + h *  yi * xi
                b = b + h * yi
sns.set(style='whitegrid')
sns.scatterplot(x='x1',y='x2',hue='y',data=data,)
x=np.linspace(-1,3)
plt.plot(x,-(w.tolist()[0][0]/w.tolist()[0][1])*x+b.tolist()[0][0])
plt.show()

画图结果:

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/308624.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号