栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Svm支持向量机代码实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Svm支持向量机代码实现

import numpy as np

#我们获取x和y,分别为x为特征矩阵(3, 10),有3个数据,每个数据10个特征,y为类别向量(3,),有三个数据,每个数据对应一个类别,
#数据的特征和类别在x和y上的索引是对应的
cat  = np.array([0.5,2.4,3.2,2.4,5.5,3.5,2.2,1.4,2.0,4.0])
dog  = np.array([3.4,1.0,0.1,1.8,6.4,1.2,2.0,1.1,5.2,3.4])
frog = np.array([3.3,2.0,1.1,1.4,1.1,5.4,5.5,5.3,1.2,1.8])
y = np.array([0,1,2])
x = np.array([cat,dog,frog])
# [[0.5 2.4 3.2 2.4 5.5 3.5 2.2 1.4 2.  4. ]
#  [3.4 1.  0.1 1.8 6.4 1.2 2.  1.1 5.2 3.4]
#  [3.3 2.  1.1 1.4 1.1 5.4 5.5 5.3 1.2 1.8]]

def init_para(shape0,shape1):
    #w和b为我们初始化的参数,w为一个矩阵,行为特征数,列为类别数,b为类别数的多元向量
    #rand为对w和b的每个值赋予服从于均方分布的值,randn为正态分布
    w,b = np.random.rand(shape0,shape1),np.random.rand(shape1)
    # 返回一个0-1的随机数
    # np.random.rand(shape0,shape1)s0为个数,s1为维度
    return w,b#(10,3)

def get_loss(x,y):
    w,b = init_para(x.shape[1],y.shape[0])#(10,3)w=10行3列,b=3行1列,    w=x的列数,b=y的行数
    loss_SVM,target = 0,[]
    for i in range(y.shape[0]):#y的行数=3
        #s是x和w的矩阵乘加上b,s为(3,)向量,可以理解为将x的10维特征通过矩阵乘的方式转换为3元的分类向量
        s = np.matmul(x[i],w) + b
        #计算svm损失的公式,s_y即将syi就是正确的那个类别对应的s值,转换为3元向量,即乘以一个shape为(3,)的全为1的向量
        s_y = s[y[i]] * np.ones(y.shape[0])#    [0,1,2]*3个数的一维数组(=3)
        #计算svm损失
        loss_i = s - s_y + np.ones(y.shape[0])   #   +3个数的一维数组
        #loss_i>0时实行max操作
        loss_SVM += np.sum(loss_i[loss_i > 0]) - 1

        #分类,即对分类向量的每个值进行谁大选谁的操作
        index = np.where(s == np.max(s))[0]
        target.append(index[0])
    #因为有多个数据点,所以需要求每个数据点的平均SVM损失
    loss_SVM = loss_SVM / y.shape[0]
    print(loss_SVM,w,b,target)
    print()
    print(s,s_y,loss_i)
    return loss_SVM,w,b,target

def opt(num,x,y,k):
    #优化为我们制造num个W和b,对应有num个s向量,对应能算出num个SVM损失函数值
    W_list,b_list,loss_list,target_list=[],[],[],[]
    for i in range(num):
        loss,W,b,target = get_loss(x,y)
        if i % 500 == 0:
            print('第%s次,loss为:%s' % (i,loss))
        W_list.append(W)
        b_list.append(b)
        loss_list.append(loss)
        target_list.append(target)
    #求取这num个w和b中svm损失函数值最小的k个值对应的索引,即argsort是显示排序后的索引向量
    loss_list = np.array(loss_list).argsort()
    #loss_list[:k]代表loss_list的前K个索引,也就是前k个最小的损失值对应的索引,然后返回这些索引对应的w和b和target
    W_k = np.array(W_list)[loss_list[:k]]
    b_k = np.array(b_list)[loss_list[:k]]
    target_k = np.array(target_list)[loss_list[:k]]
    return W_k,b_k,target_k

if __name__ == "__main__":
    get_loss(x,y)
    # W,b,target = opt(85000,x,y,3)
    # print("误差最小的%s个W:" % 3,W)
    # print("误差最小的%s个b:" % 3,b)
    # print("误差最小的%s个分类标签:" % 3,target)

第83500次,loss为:3.343499357813641
第84000次,loss为:6.181845615536322
第84500次,loss为:5.8686677536000404
误差最小的3个W: [[[0.32844404 0.7163863  0.10854947]
  [0.07754455 0.54069304 0.15794544]
  [0.90056409 0.37977735 0.13973564]
  [0.54325316 0.61890981 0.36994007]
  [0.96696634 0.71622404 0.20797589]
  [0.76120027 0.34579976 0.79569331]
  [0.42550202 0.13230416 0.80459579]
  [0.04706553 0.68142314 0.35390186]
  [0.12922677 0.92854093 0.70876595]
  [0.36000685 0.21628632 0.71062637]]

 [[0.62378339 0.76662255 0.72736281]
  [0.37096624 0.25489383 0.67981976]
  [0.87028916 0.14178342 0.50992209]
  [0.65072771 0.72093656 0.64541691]
  [0.81941726 0.62474984 0.49126317]
  [0.39156395 0.249462   0.24419969]
  [0.52694055 0.79230952 0.87783305]
  [0.57168218 0.46794002 0.79818214]
  [0.06980084 0.48980119 0.51695222]
  [0.9634186  0.9383427  0.18663827]]

 [[0.27688008 0.96886982 0.97909507]
  [0.37407184 0.62133562 0.05686489]
  [0.92485235 0.21030799 0.02227001]
  [0.38108974 0.23170431 0.98025494]
  [0.72946106 0.79353999 0.43862167]
  [0.06697129 0.08609185 0.0977768 ]
  [0.12470004 0.65315601 0.55706641]
  [0.85267869 0.34243389 0.92206649]
  [0.32069096 0.86837812 0.75928898]
  [0.7772152  0.1650428  0.08815647]]]
误差最小的3个b: [[0.80485332 0.13332358 0.72415515]
 [0.72241767 0.4545099  0.81502036]
 [0.61183262 0.82468408 0.66842231]]
误差最小的3个分类标签: [[0 1 2]
 [0 1 2]
 [0 1 2]]
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/293669.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号