机器学习入门(五)用KNN模型,预测某一点的类别----K最近邻算法处理多元分类任务
作者:冯德平(山野雪人)
这是本人学习《深入浅出Python机器学习》(参见3.2.2 K最近邻算法处理多元分类任务)后写的一篇文章,本文用KNN建立模型,并对给出的一个点(-1.4,-1.8),给出了如何求出这个点的分类的计算实例。
from sklearn.datasets import make_blobs
#导入画图工具
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
#生成样本数为5 0 0,分类数为5的数据集
data2 =make_blobs(n_samples=480,centers=6,random_state=7)
#print(data2)
X2,y2 = data2
#print(X2,y2)
#用散点图将数据集进行可视化
plt.scatter(X2[:,0],X2 [:,1] , c=y2, cmap=plt.cm.spring,edgecolor ='k')
plt.show() #绘图1
#拟合数据:
clf = KNeighborsClassifier()
clf.fit(X2,y2)
#有一个点(-1.4,-1.8),预测这点应该分在哪一类:
print('新数据点的分类是',clf.predict(np.c_[(-1.4,-1.8)])) #预测数据
#绘图:
x_min,x_max =X2[:,0].min()-1,X2[:,0].max () + 1
y_min,y_max =X2[:,1].min()-1,X2[:,1].max() + 1
xx,yy = np.meshgrid(np.arange(x_min,x_max,.02),
np.arange(y_min,y_max,.02))
Z = clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx,yy,Z,cmap=plt.cm.Spectral) # 原文用的Pastell是错误的,而应该用Spectral
plt.scatter(X2[:,0],X2[:,1], c=y2,cmap=plt.cm.spring,edgecolor = 'k')
plt.scatter(-1.4,-1.8 , marker ='*',c ='red',s=200) #绘出要预测的点
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt. title ("K最近邻算法处理多元分类任务")
plt.show() #绘图2
‘’’
运行结果:
新数据点的分类是 [4]
‘’’



