栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

画混淆矩阵

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

画混淆矩阵

import matplotlib.pyplot as plt
from tensorflow import confusion_matrix
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models
import scipy.io as scio
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
from tensorflow.keras import backend as K,Model # ????
import os  # ????

import tensorflow as tf  # ????

# -*- coding: utf-8 -*-
"""
Created on Thu Jul 30 16:47:47 2020

@author: Noah
"""

import os  # ????

import tensorflow as tf  # ????

from keras import backend as K, Input, Model  # ????

# Expose only the first GPU to TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# TF1-style session: let GPU memory grow on demand instead of reserving the
# whole device up front, then install it as the session Keras uses.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
K.set_session(tf.Session(config=config))

from keras import models
# The imported names span several lines, so the list must be parenthesised
# (a bare line break inside an import statement is a SyntaxError).
from keras.layers import (Flatten, Dense, BatchNormalization, Dropout, Conv1D,
                          SeparableConv1D, Lambda, Concatenate,
                          GlobalAveragePooling1D, Activation)

# Per-example input shape (steps, channels) = (128, 2); presumably 128 I/Q
# samples, since the test data below is read from an 'IQ' variable -- confirm.
input_shape = (128, 2)
inputs = Input(shape=input_shape)
# Module-level default; not read anywhere in this script.
stride = 1

def _group_conv(x, filters, kernel, stride, groups):
    """Apply a grouped 1-D convolution to ``x``.

    The channels of ``x`` are split into ``groups`` equal, contiguous slices.
    Each slice gets its own ``Conv1D`` with ``filters // groups`` filters,
    and the per-group outputs are concatenated back along the channel axis.

    Args:
        x: rank-3 Keras tensor (batch, steps, channels) for channels_last, or
            (batch, channels, steps) for channels_first.
        filters: total number of output channels; must be divisible by
            ``groups``.
        kernel: convolution kernel size.
        stride: convolution stride.
        groups: number of channel groups.

    Returns:
        Keras tensor holding the concatenated per-group convolution outputs.

    Raises:
        ValueError: if ``filters`` or the input channel count is not
            divisible by ``groups``.
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    in_channels = K.int_shape(x)[channel_axis]

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if filters % groups != 0:
        raise ValueError(
            f"filters ({filters}) must be divisible by groups ({groups})")
    if in_channels % groups != 0:
        # Otherwise the trailing in_channels % groups channels would be
        # silently dropped by the slicing below.
        raise ValueError(
            f"input channels ({in_channels}) must be divisible by groups ({groups})")

    nb_ig = in_channels // groups  # input channels per group
    nb_og = filters // groups      # output channels per group

    # Slice bounds are passed through ``Lambda(arguments=...)`` rather than
    # captured from the loop variable. A closure over ``g`` would late-bind:
    # whenever the layer function is re-invoked (e.g. after model.save/load
    # or a re-trace), every group would see the final loop value.
    if channel_axis == -1:
        def take(z, start, end):
            return z[:, :, start:end]
    else:
        def take(z, start, end):
            return z[:, start:end, :]

    gc_list = []
    for g in range(groups):
        bounds = {'start': g * nb_ig, 'end': (g + 1) * nb_ig}
        x_group = Lambda(take, arguments=bounds)(x)
        gc_list.append(Conv1D(filters=nb_og, kernel_size=kernel, strides=stride,
                              padding='same', use_bias=False)(x_group))

    return Concatenate(axis=channel_axis)(gc_list)

def LightNet(input_length=128, num_classes=9):
    """Build and compile the LightNet classifier.

    Layers: Conv1D(128 filters, kernel 16, ReLU) -> BatchNorm -> Dropout ->
    grouped Conv1D (64 filters, kernel 8, 8 groups) -> BatchNorm -> Dropout
    -> global average pooling -> Dense(num_classes) -> softmax.

    Args:
        input_length: samples per example; each sample has 2 channels.
        num_classes: number of output classes.

    Returns:
        A Keras ``Model`` compiled with categorical cross-entropy, Adam and
        an accuracy metric. Its summary is printed as a side effect.
    """
    # A fresh input tensor per call, so the function does not depend on a
    # module-level ``inputs`` global (this script builds the network twice).
    model_input = Input(shape=(input_length, 2))

    x = Conv1D(128, 16, activation='relu', padding='same')(model_input)
    x = BatchNormalization()(x)
    x = Dropout(0.003)(x)

    x = _group_conv(x, filters=64, kernel=8, stride=1, groups=8)
    x = BatchNormalization()(x)
    x = Dropout(0.003)(x)

    x = GlobalAveragePooling1D()(x)

    x = Dense(num_classes)(x)
    predicts = Activation('softmax')(x)

    model = Model(model_input, predicts)
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    model.summary()
    return model

# Build an untrained copy of the network and write it to disk; the file is
# only written so the size of the saved model can be inspected.
model = LightNet()
model.save(filepath='m.hdf5')

from numpy import array  # kept at module level; not needed below
Data_path = "../10modulations_L=128_train6000_val1000_test10000/"

# Rebuild the network and load the trained weights to be evaluated.
model = LightNet()
model.load_weights("GlobalModel.hdf5")

# ---- Test set at SNR = 20 dB --------------------------------------------
data_path = Data_path + "test/snr=" + str(20) + ".mat"
data = scio.loadmat(data_path)
x = data.get('IQ')
if x is None:
    raise KeyError(f"no 'IQ' variable found in {data_path}")

# Ground-truth labels. This assumes the file stores N examples of class 0,
# then N of class 1, ... up to class 8 -- TODO confirm against the data
# generator. Labels stay float, shape (9*N, 1), as before.
N = 10000
NUM_CLASSES = 9
y_flag = np.repeat(np.arange(NUM_CLASSES, dtype=float), N).reshape(-1, 1)
if x.shape[0] != y_flag.shape[0]:
    raise ValueError(
        f"{data_path} holds {x.shape[0]} examples, expected "
        f"{y_flag.shape[0]} ({NUM_CLASSES} classes x {N})")
y = to_categorical(y_flag, NUM_CLASSES)
X_test = x
Y_test = y

# Report the metrics instead of computing and discarding them.
loss, acc = model.evaluate(X_test, Y_test, batch_size=100, verbose=0)
print(f"SNR=20dB test loss: {loss:.4f}, accuracy: {acc:.4f}")

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          cmap=plt.cm.Blues,
                          title="Confusion Matrix (SNR=20db)"):
    """Draw confusion matrix ``cm`` on the current matplotlib figure.

    Args:
        cm: square matrix with rows indexed by true class and columns by
            predicted class (the axes are labelled that way below).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, scale every row to sum to 1 and annotate cells
            with two decimals; otherwise annotate cells with the raw counts.
        cmap: matplotlib colormap for the image.
        title: figure title; the default reproduces the previous fixed text.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row sum; leave that row at 0
        # instead of dividing by zero (which produced NaNs).
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        fmt = "%.2f"
    else:
        fmt = "%d"

    # Use white text on the darker half of the colour scale so the numbers
    # stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, fmt % cm[i, j], size=10,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, size=14, family='Times New Roman', rotation=30)
    plt.yticks(tick_marks, classes, size=14, family='Times New Roman')

    plt.title(title, size=14)
    plt.ylabel('True label', size=13, family='Times New Roman')
    plt.xlabel('Predicted label', size=13, family='Times New Roman')

    plt.colorbar()
# ---- 20 dB: predicted labels and confusion matrix --------------------------
proba = model.predict(X_test, batch_size=100, verbose=1)
# Predicted class = index of the highest softmax score, vectorised over all
# examples (no hard-coded sample count, and the builtin ``max`` is not
# shadowed). On an exact tie np.argmax picks the first index.
y_pred = np.argmax(proba, axis=1).astype(float).reshape(-1, 1)
cm = confusion_matrix(y_flag, y_pred)
classes = ["2FSK", "4FSK", "8FSK", "BPSK", "QPSK", "8PSK", "16QAM","128QAM", "256QAM"]

plt.figure()
plot_confusion_matrix(cm, classes, normalize=True, cmap=plt.cm.Blues)
plt.show()

# ---- 10 dB: same evaluation on the lower-SNR test file ---------------------
data_path = Data_path + "test/snr=" + str(10) + ".mat"
data = scio.loadmat(data_path)
X_test = data.get('IQ')
if X_test is None:
    raise KeyError(f"no 'IQ' variable found in {data_path}")
proba = model.predict(X_test, batch_size=100, verbose=1)
# The ground-truth labels built for the 20 dB file are reused, so the two
# files must contain the same number of examples in the same class order.
if proba.shape[0] != y_flag.shape[0]:
    raise ValueError(
        f"{data_path} holds {proba.shape[0]} examples, expected {y_flag.shape[0]}")
# Predicted class = index of the highest softmax score (first index on a tie).
y_pred = np.argmax(proba, axis=1).astype(float).reshape(-1, 1)
cm = confusion_matrix(y_flag, y_pred)
classes = ["2FSK", "4FSK", "8FSK", "BPSK", "QPSK", "8PSK", "16QAM","128QAM", "256QAM"]

plt.figure()
plot_confusion_matrix(cm, classes, normalize=True, cmap=plt.cm.Blues)
# plot_confusion_matrix titles the figure for 20 dB by default; replace that
# title so the saved image is labelled with the SNR it actually shows.
plt.title("Confusion Matrix (SNR=10db)", size=14)
plt.savefig("demo9_class_10db.png")
plt.show()
 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/326515.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号