How do I create a confusion matrix of predictions and ground-truth labels with TensorFlow?

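This code works for me. I put it together myself :) It builds the confusion matrix with `pandas.crosstab` and prints per-class metrics with scikit-learn's `classification_report`.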

```python
import numpy as np
import pandas as pd
import tensorflow as tf  # TensorFlow 1.x graph API (tf.Session); under TF 2.x use tf.compat.v1
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import classification_report


def print_confusion_matrix(plabels, tlabels):
    """
    Build the confusion matrix for the different classes,
    to see where the errors are.

    Input:
    -----------
    plabels: predicted labels for the classes
    tlabels: true labels for the classes

    Adapted from: http://stackoverflow.com/questions/2148543/how-to-write-a-confusion-matrix-in-python
    """
    plabels = pd.Series(plabels)
    tlabels = pd.Series(tlabels)

    # Cross tabulation: rows = actual labels, columns = predicted labels, plus 'All' totals
    df_confusion = pd.crosstab(tlabels, plabels, rownames=['Actual'], colnames=['Predicted'], margins=True)
    return df_confusion


def confusionMatrix(text, Labels, y_pred, not_partial):
    # Convert one-hot ground-truth rows to class indices
    y_actu = np.where(Labels[:] == 1)[1]
    df = print_confusion_matrix(y_pred, y_actu)
    print("\n", df)
    # plt.imshow(df.values)  # optional: visualise the matrix (needs matplotlib)
    if not_partial:
        print("\n", classification_report(y_actu, y_pred))
    print("\n\t------------------------------------------------------\n")


def do_eval(message, sess, correct_prediction, accuracy, pred, X_, y_, x, y):
    # Boolean vector: was each sample classified correctly?
    predictions = sess.run([correct_prediction], feed_dict={x: X_, y: y_})
    # Predicted class index for each sample
    prediction = tf.argmax(pred, 1)
    labels = prediction.eval(feed_dict={x: X_, y: y_}, session=sess)
    print(message, accuracy.eval({x: X_, y: y_}, session=sess), "\n")
    confusionMatrix("Partial Confusion matrix", y_, predictions[0], False)  # actual class vs. correct/incorrect
    confusionMatrix("Complete Confusion matrix", y_, labels, True)         # actual class vs. predicted class


# Launch the graph.
# x, y, pred, cost, optimizer, init, batch_size, training_epochs, display_step,
# loss_history, train_acc_history, val_acc_history and the X_/y_ NumPy arrays
# (train, val, gold) are defined earlier in the model code.
with tf.Session() as sess:
    sess.run(init)
    data_size = len(X_train)
    # Round up so the last partial batch is used but no empty batch is fed
    num_batches_per_epoch = int(np.ceil(data_size / batch_size))

    # Build the evaluation ops once, outside the training loop
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    for epoch in range(training_epochs):
        avg_cost = 0.
        # Shuffle the data at each epoch
        shuffle_indices = np.random.permutation(np.arange(data_size))
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            # Pick a random batch of the given size from the training set
            batch_idx = shuffle_indices[start_index:end_index]
            batch_xs, batch_ys = X_train[batch_idx], y_train[batch_idx]
            # Fit training using batch data
            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})
            # Compute average loss
            avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys}) / num_batches_per_epoch
        # Append loss
        loss_history.append(avg_cost)
        # Display logs per epoch step
        if epoch % display_step == 0:
            # Training accuracy
            trainAccuracy = accuracy.eval({x: X_train, y: y_train})
            train_acc_history.append(trainAccuracy)
            # Validation accuracy
            valAccuracy = accuracy.eval({x: X_val, y: y_val})
            val_acc_history.append(valAccuracy)
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost),
                  "train=", trainAccuracy, "val=", valAccuracy)
    print("Optimization Finished!\n")

    # Evaluation of the model on the gold test set
    do_eval("Accuracy of Gold Test set Results: ", sess, correct_prediction, accuracy, pred, X_gold, y_gold, x, y)
```
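The confusion matrix itself is just a count of (actual, predicted) pairs. Below is a minimal pure-Python sketch of what the `pd.crosstab(..., margins=True)` call computes, assuming the ground truth arrives as one-hot rows (as `y_` does in `do_eval`) and the predictions as class indices or as True/False correctness flags. `one_hot_to_index` and `confusion_with_margins` are hypothetical helper names for illustration, not pandas or TensorFlow functions, and the result is a nested dict rather than a DataFrame.

```python
from collections import Counter


def one_hot_to_index(one_hot_rows):
    """Return the position of the 1 in each one-hot row.

    Plays the role of np.where(Labels[:] == 1)[1] in the code above.
    """
    return [row.index(1) for row in one_hot_rows]


def confusion_with_margins(y_true, y_pred):
    """Cross-tabulate actual vs. predicted labels with 'All' totals.

    Returns table[actual][predicted] -> count. Each row gets an extra
    'All' entry (row total) and table['All'] holds the column totals,
    similar in spirit to pd.crosstab(..., margins=True).
    """
    row_labels = sorted(set(y_true))
    col_labels = sorted(set(y_pred))
    counts = Counter(zip(y_true, y_pred))

    table = {}
    for actual in row_labels:
        row = {pred: counts[(actual, pred)] for pred in col_labels}
        row["All"] = sum(row.values())
        table[actual] = row
    # Column totals; the 'All' entry of this row is the grand total.
    table["All"] = {
        col: sum(table[actual][col] for actual in row_labels)
        for col in col_labels + ["All"]
    }
    return table


if __name__ == "__main__":
    y_onehot = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]
    y_true = one_hot_to_index(y_onehot)  # [0, 1, 2, 1, 0]
    y_pred = [0, 2, 2, 1, 1]

    # "Complete" matrix: actual class vs. predicted class
    for label, row in confusion_with_margins(y_true, y_pred).items():
        print(label, row)

    # "Partial" matrix: actual class vs. whether the prediction was correct
    correct = [t == p for t, p in zip(y_true, y_pred)]
    for label, row in confusion_with_margins(y_true, correct).items():
        print(label, row)
```

Passing the correctness flags instead of the predicted classes gives the False/True table, which is what `do_eval` does for its first ("partial") matrix.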

Here is sample output:

```
Accuracy of Gold Test set Results:  0.642608

Predicted  False  True  All
Actual
0             20    46   66
1              3     1    4
2             21     1   22
3              8     4   12
4             16     7   23
5             54   259  313
6             41    14   55
7             11     2   13
8             48    94  142
9             29     4   33
10            17     4   21
11            39   116  155
All          307   552  859

Predicted   0  1  2   3   4    5   6   7    8   9  10   11  All
Actual
0          46  0  0   0   0    8   0   2    2   2   0    6   66
1           0  1  0   1   0    2   0   0    0   0   0    0    4
2           3  0  1   3   0   12   0   0    1   0   0    2   22
3           2  0  0   4   1    3   1   1    0   0   0    0   12
4           1  0  0   0   7   12   0   0    1   0   0    2   23
5           8  0  0   1   5  259   9   0    9   3   1   18  313
6           1  0  0   1   6   30  14   1    2   0   0    0   55
7           3  0  0   0   0    2   0   2    4   0   1    1   13
8           6  0  0   1   1   18   0   3   94   8   1   10  142
9           9  0  0   0   0    1   1   1    9   4   0    8   33
10          1  0  0   0   3    6   0   1    1   0   4    5   21
11          5  1  0   1   0   18   1   0    6   5   2  116  155
All        85  2  1  12  23  371  26  11  129  22   9  168  859

             precision    recall  f1-score   support

          0       0.54      0.70      0.61        66
          1       0.50      0.25      0.33         4
          2       1.00      0.05      0.09        22
          3       0.33      0.33      0.33        12
          4       0.30      0.30      0.30        23
          5       0.70      0.83      0.76       313
          6       0.54      0.25      0.35        55
          7       0.18      0.15      0.17        13
          8       0.73      0.66      0.69       142
          9       0.18      0.12      0.15        33
         10       0.44      0.19      0.27        21
         11       0.69      0.75      0.72       155

avg / total       0.64      0.64      0.62       859
```


Reposted from www.mshxw.com (please credit the source when reposting).
Article URL: https://www.mshxw.com/it/623588.html