栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

python 使用numpy计算混淆矩阵

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

python 使用numpy计算混淆矩阵

python 使用numpy计算混淆矩阵

假如一个模型要预测的类别有三个,分别为A、B、C,使用模型预测测试集得到以下结果:

我们一列一列来看,先看第一列:30、15、5

这里我们的测试集有且只有三个分类A、B、C;也就是真实分类A、B、C就对应着测试集的总体,对于一个样本的预测也只可能是这三者之一。

模型预测值为A的,实际标签不一定就是A,但它一定是A、B、C三者之一,这里预测为A的前提下:真实值为A的有30个、真实值为B的有15个、真实值为C的有5个。

上述表格用numpy表示如下

import numpy as np
# 混淆矩阵
c_matrix = np.array([[30,  7,   3],
                    [ 15, 22,  3],
                    [ 5,   1,  14]])
print(c_matrix.shape)  # (3,3)
print(c_matrix[0][1])  # 7

也可对混淆矩阵进行标准化,使其值在0到1之间

# 混淆矩阵标准化(这里使用L1规范化,是对每一行来说规范化)
print(c_matrix.sum(axis=1))
print(c_matrix.sum(axis=1)[:, np.newaxis])  
c_matrix = c_matrix / c_matrix.sum(axis=1)[:, np.newaxis]
print(c_matrix)
[40 40 20]
[[40]
 [40]
 [20]]
[[0.75  0.175 0.075]
 [0.375 0.55  0.075]
 [0.25  0.05  0.7  ]]

在测试模型时计算

confusion_matrix = np.zeros( (len(class_names), len(class_names)) )  # 混淆矩阵
for images, labels in test_ds.take(total_batch):
        labels = labels.numpy()
        predictions = model.predict(images)
        score = tf.nn.softmax(predictions)
        for index, elem in enumerate(score):
            r, c = np.argmax(elem), labels[index]
            confusion_matrix[r][c] += 1
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/822883.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号