栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

记录机器学习入门——KNN,鸢尾花分类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

记录机器学习入门——KNN,鸢尾花分类

转发大佬的作品集
感觉自己又变菜了
鸢尾花数据集下载
不要误会,这真的是我写的(应该没人会说是他写的)

# -*-coding = utf-8-*-
# @Time:2021/9/28 19:17
# @Author TG
# @File :鸢尾花分类.py
# @Software: PyCharm
import numpy as np
import operator
import pandas as pd
from os import listdir
import collections

# 1. 计算已知类别数据集中的点与当前点之间的距离;
# 2. 按照距离递增次序排序;
# 3. 选取与当前点距离最小的k个点;
# 4. 确定前k个点所在类别的出现频率;
# 5. 返回前k个点所出现频率最高的类别作为当前点的预测分类。


#数据归一化
def autoNorm(dataSet):
	#获得数据的最小值
	minVals = dataSet.min(0)
	maxVals = dataSet.max(0)
	#最大值和最小值的范围
	ranges = maxVals - minVals
	#shape(dataSet)返回dataSet的矩阵行列数
	normDataSet = np.zeros(np.shape(dataSet))
	#返回dataSet的行数
	m = dataSet.shape[0]
	#原始值减去最小值
	normDataSet = dataSet - np.tile(minVals, (m, 1))
	#除以最大和最小值的差,得到归一化数据
	normDataSet = normDataSet / np.tile(ranges, (m, 1))
	#返回归一化数据结果,数据范围,最小值
	return normDataSet, ranges, minVals


#获取归一化数据,训练集索引,测试集索引,标签
def get_data():

	df=pd.read_csv('鸢尾花数据集.csv',engine='python')
	label=list(df['species'])
	number_data=[]
	for i, r in df.iterrows():
		number_data.append([r[0], r[1], r[2], r[3]])

	data_array=np.array(number_data)
	normMat, ranges, minVals = autoNorm(data_array)#归一化数据结果,数据范围,最小值
	#训练集
	training_group1=[i for i in range(40)]
	training_group2=[i for i in range(50,90)]
	training_group=training_group1+training_group2
	test_group1=[i for i in range(40,50)]
	test_group2=[i  for i in range(90,100)]
	test_group=test_group1+test_group2
	return normMat,training_group,test_group,label

data,training_group_index,test_group_index,label=get_data()#归一化数据,训练集索引,测试集索引,标签

#获取某一组测试数据预测标签的正确性,正确为1,错误为0
def test(test_number):#test是测试数据的索引
	distant=[]
	for i in training_group_index:#计算与测试集的距离
		s=np.sum((data[test_number]-data[i])**2)**0.5
		distant.append(s)

	distant1=distant.copy()
	distant1.sort()#升序排列

	labels=[label[training_group_index[distant.index(i)]] for i in distant1]

	klabels=labels[0:10]

	most_commom = collections.Counter(klabels).most_common(1)[0][0]#collections.Counter计算最多出现的花名
	if(most_commom==label[test_number]):#比较比较出的标签和
		return 1
	else:
		return 0



#测试KNN的准确率
def	test_accuracy_ratio():
	data,training_group_index,test_group_index,label=get_data()
	s=0
	for i in test_group_index:

		s=s+test(i)
	print(s/len(test_group_index))

test_accuracy_ratio()

遇到的问题:
1、给列表复制的时候不能直接等,需要用:
distant1=distant.copy()
distant1.sort()#升序排列
2、统计前k个出现最多的花名:
可用到 collections.Counter
collections.Counter(klabels).most_common(1)[0][0]
collections.Counter

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/280752.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号