栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

KNN实现手写字体的识别

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

KNN实现手写字体的识别

KNN实现手写字体的识别
    • KNN算法介绍:
    • 数据的导入:
      • 导入包
      • 导入数据集
    • 数据集介绍:
    • 数据集的分割:
    • 定义KNN函数:
    • 评估准确率:
    • 完整代码:

KNN算法介绍:

点击这里查看KNN算法代码及其介绍

数据的导入: 导入包
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
import pandas as pd
导入数据集
# 手写字体的数据集导入
digtis = datasets.load_digits()
target = digtis.target
data = digtis.data
数据集介绍:

数据集情况:1797条数据

data.shape, target.shape
# (1797, 64), (1797,))

对于导入的数据集data里面的每个数据的形状是(64,),我们可以将其转化为8X8像素的数据,将第一个数据进行可视化展示:
形状转换:

ima = data[0].reshape(8, 8)

Out:
array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
       [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
       [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
       [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
       [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
       [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
       [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
       [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])

可视化:

plt.imshow(ima)

数据集的分割:
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=1)
定义KNN函数:
def knn_code(loc, k=5, order=2):  # k order是超参
    # print(order)
    diff_loc = x_train - loc
    dis_loc = np.linalg.norm(diff_loc, ord=order, axis=1)  # 没有axis得到一个数,矩阵的泛数。axis=0,得到两个数
    knn = y_train[dis_loc.argsort()[:k]]
    counts = np.bincount(knn)
    return np.argmax(counts)
评估准确率:
res = []
for i in x_test:
    res.append(knn_code(i))

acc = ((y_test == pd.Series(res))==True).sum()/len(y_test)
print("准确率:", acc)
# 准确率: 0.9944444444444445
完整代码:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# datetime:2021/11/22 22:54
# software: PyCharm
# 导入包
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
import pandas as pd

# 手写字体的数据集导入
digtis = datasets.load_digits()
target = digtis.target
data = digtis.data

# 可视化展示
ima = data[0].reshape(8, 8)
plt.imshow(ima)
plt.show()

# 数据集分割
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=1)

# KNN函数
def knn_code(loc, k=5, order=2):  # k order是超参
    # print(order)
    diff_loc = x_train - loc
    dis_loc = np.linalg.norm(diff_loc, ord=order, axis=1)  # 没有axis得到一个数,矩阵的泛数。axis=0,得到两个数
    knn = y_train[dis_loc.argsort()[:k]]
    counts = np.bincount(knn)
    return np.argmax(counts)

# acc
res = []
for i in x_test:
    res.append(knn_code(i))

acc = ((y_test == pd.Series(res))==True).sum()/len(y_test)
print("准确率:", acc)



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/618422.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号