栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

手写字体识别

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

手写字体识别

import sys,os
import numpy as np
import matplotlib.pylab as plt
import pickle#pickle提供了一个简单的持久化功能。可以将对象以文件的形式存放在磁盘上
sys.path.append('E:deep learning  by pythondeep learn by python') # 为了导入父目录中的文件而进行的设定
from dataset.mnist import load_mnist#导入数据集
from PIL import Image
def step_function(x):
    return np.array(x > 0, dtype=np.int)

X = np.arange(-5.0, 5.0, 0.1)
Y = step_function(X)

plt.plot(X, Y)
plt.ylim(-0.1, 1.1)  # 指定图中绘制的y轴的范围
plt.show()

#sigmoid函数
def sigmoid(x):
    return 1 / (1 + np.exp(-x))    
X = np.arange(-5.0, 5.0, 0.1)
Y = sigmoid(X)
plt.plot(X, Y)
plt.ylim(-0.1, 1.1)
plt.show()

#relu函数
def relu(x):
    return np.maximum(0, x)

x = np.arange(-5.0, 5.0, 0.1)
y = relu(x)
plt.plot(x, y)
plt.ylim(-1.0, 5.5)
plt.show()

def softmax(x):
    C = np.max( x )
    exp_a = np.exp( x - C )
    sum_exp = np.sum( exp_a )
    y = exp_a / sum_exp
    return y
#进行训练
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(0, len(x),100):#用for循环逐一取出保存在x中的图像数据
    y = predict(network, x[i])#predict函数进行分类,以numpy库数组的形式输出各个标签的对应的概率
    p= np.argmax(y)#获取概率最高的元素的索引
    if p == t[i]:
        accuracy_cnt += 1
print("Accuracy(sigmoid):" + str(100*float(accuracy_cnt) / len(x)))
###relu
def predict2(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = relu(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = relu(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(0, len(x),100):
    y = predict2(network, x[i])
    p= np.argmax(y)
    if p == t[i]:
        accuracy_cnt += 1
print("Accuracy(relu):" + str(100*float(accuracy_cnt) / len(x)))

###step_function
def predict3(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    a1 = np.dot(x, W1) + b1
    z1 = step_function(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = step_function(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y
    

x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(0, len(x),100):
    y = predict3(network, x[i])
    p= np.argmax(y)
    if p == t[i]:
        accuracy_cnt += 1
print("Accuracy(step_function):" + str(100*float(accuracy_cnt) / len(x)))

 

 

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/286170.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号