《统计学习方法》感知机代码实现
# 感知机实现
import numpy as np
#数据
x = np.array([[3,3],[4,3],[1,1]])
label = np.array([1,1,-1])
yita = 1#步长
w = np.array([0, 0])#初始化w
b = 0#初始化b
#第一步,求出初始值时的各个点的预测值,便于观察是否误判
def y_predict(w):
y_pre = (np.dot(x,w.T)+b)*label.T
return y_pre
# print(y_pre)
k = 0
#开始进行感知机算法
while True:
k = k+1
print("第%d次" % k)
y_pre = y_predict(w)
#针对第一次的预测值,遍历是否误判
for i in range(len(y_pre)):
if y_pre[i] <= 0:#误判条件
w = w + yita * label[i]*x[i]#更新参数
b = b + yita*label[i]
print(w,b)
break#跳出for循环
if np.min(y_pre) > 0:#如果预测值列表的最小值都大于0,肯定全部大于0,这时所有点都是判断正确
print(w)
print(b)
break#跳出while循环



