栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PCA降维和TSNE降维的对比

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PCA降维和TSNE降维的对比

# -*- coding: utf-8 -*-
"""
Created on Mon Nov 15 21:48:20 2021
@author: guangjie2333
"""
import keras
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from keras.preprocessing import image
from sklearn import svm
from sklearn import metrics
from sklearn.model_selection import GridSearchCV
from sklearn.manifold import TSNE


# Seed NumPy's global RNG so runs are reproducible. sklearn estimators
# created with random_state=None (as TSNE is later in this script) draw
# from this global generator.
np.random.seed(1)

# Dimensionality-reduction switch: 0 -> PCA, any other value -> t-SNE.
switch = 0

# --- Data loading ---
# np.load on an .npz archive returns an NpzFile that keeps the archive
# open; the context manager closes it once the arrays have been read
# (indexing an NpzFile member loads that array into memory).
with np.load('breastmnist.npz') as BreastMnistData:
    print(BreastMnistData.files)
    train_images = BreastMnistData['train_images']
    val_images = BreastMnistData['val_images']
    test_images = BreastMnistData['test_images']
    train_labels = BreastMnistData['train_labels']
    val_labels = BreastMnistData['val_labels']
    test_labels = BreastMnistData['test_labels']

# --- Data overview ---
print(train_images.shape)
print(train_images.shape[0])

# --- Training set: flatten images and reduce to 3 dimensions ---
# Flatten each image into one row of the feature matrix (2-D as long as all
# images share one size -- assumed, TODO confirm for the dataset). A
# comprehension is used so no loop variable overwrites the `image` name
# imported from keras.preprocessing at the top of the file.
train_feature = np.array([img.flatten() for img in train_images])
print('train_feature.shape:', train_feature.shape)

if switch == 0:
    # `pca` is fitted on the training set only; other splits should be
    # projected with pca.transform(...) so they share these axes.
    pca = PCA(n_components=3)  # reduce to 3 dimensions
    train_pca = pca.fit_transform(train_feature)
    print(train_pca)
    print('train_pca:shape:', train_pca.shape)
    x, y, z = train_pca[:, 0], train_pca[:, 1], train_pca[:, 2]
else:
    tsne = TSNE(n_components=3)  # reduce to 3 dimensions
    train_tsen = tsne.fit_transform(train_feature)
    print(train_tsen)
    print('train_tsen:shape:', train_tsen.shape)
    x, y, z = train_tsen[:, 0], train_tsen[:, 1], train_tsen[:, 2]

# --- 3-D scatter of the reduced training set, coloured by label ---
# Both reduction methods share the same plot; only the title differs.
# (For a 2-D view, scatter just the first two components instead:
#  plt.scatter(x, y, c=train_labels); plt.colorbar())
# NOTE(review): nothing here calls plt.show(); outside an interactive
# backend the figure may never be displayed.
method_name = "PCA" if switch == 0 else "TSNE"

fig = plt.figure(figsize=(12, 6))
ax = fig.add_subplot(111, projection='3d')  # create a 3-D axes
ax.scatter(x, y, z, c=train_labels)
ax.set_zlabel('Z')  # axis labels
ax.set_ylabel('Y')
ax.set_xlabel('X')
plt.title("method of " + method_name)




# --- Validation set: flatten and reduce to 3 dimensions ---
# Comprehension avoids rebinding the keras `image` import as a loop variable.
val_feature = np.array([img.flatten() for img in val_images])
print('val_feature.shape:', val_feature.shape)

if switch == 0:
    # Project with the PCA fitted on the training set. Re-fitting PCA on the
    # validation data would yield different axes (and possibly flipped
    # signs), so a classifier trained on the training projection would be
    # evaluated on incompatible features.
    val_pca = pca.transform(val_feature)
    print(val_pca)
    print('val_pca:shape:', val_pca.shape)
else:
    # NOTE: sklearn's TSNE has no transform(); this embedding is fitted
    # independently of any training-set embedding, so its coordinates are
    # not comparable with training coordinates. Treat any classifier
    # trained on one embedding and scored on this one with caution.
    tsne = TSNE(n_components=3)  # reduce to 3 dimensions
    val_tsne = tsne.fit_transform(val_feature)
    print(val_tsne)
    print('val_tsne:shape:', val_tsne.shape)


# --- Test set: flatten and reduce to 3 dimensions ---
# Comprehension avoids rebinding the keras `image` import as a loop variable.
test_feature = np.array([img.flatten() for img in test_images])
print('test_feature.shape:', test_feature.shape)

if switch == 0:
    # Learn the axes from the *training* features, then project the test
    # set onto them, so test coordinates live in the same space as the
    # training projection. Fitting on train_feature explicitly here avoids
    # depending on whatever `pca` was last bound to earlier in the script.
    pca = PCA(n_components=3)  # reduce to 3 dimensions
    pca.fit(train_feature)
    test_pca = pca.transform(test_feature)
    print(test_pca)
    print('test_pca:shape:', test_pca.shape)
else:
    # NOTE: sklearn's TSNE has no transform(); this embedding of the test
    # set is fitted independently and is not comparable with a separately
    # fitted training embedding.
    tsne = TSNE(n_components=3)  # reduce to 3 dimensions
    test_tsne = tsne.fit_transform(test_feature)
    print(test_tsne)
    print('test_tsne:shape:', test_tsne.shape)



# --- Evaluate the reduced features with an SVM classifier ---
# Alternative (disabled): exhaustive hyper-parameter search, e.g.
#   parameters = {'kernel': ['linear', 'rbf', 'sigmoid', 'poly'],
#                 'C': np.linspace(0.1, 20, 5),
#                 'gamma': np.linspace(0.1, 20, 5)}
#   clf = GridSearchCV(svm.SVC(), parameters, cv=5, scoring='accuracy')

clf = svm.SVC(kernel='poly', C=1)

# np.ravel flattens a column-vector label array (if the labels are stored
# as (n, 1) -- TODO confirm) into the 1-D form sklearn expects; it is a
# no-op for labels that are already 1-D.
if switch == 0:
    clf.fit(train_pca, np.ravel(train_labels))
    y_val_predict = clf.predict(val_pca)
else:
    # Caveat: separately fitted t-SNE embeddings do not share coordinates,
    # so this score is not a reliable measure of generalisation.
    clf.fit(train_tsen, np.ravel(train_labels))
    y_val_predict = clf.predict(val_tsne)

# Validation accuracy
valResult = metrics.accuracy_score(np.ravel(val_labels), y_val_predict)
print("Accuracy", valResult)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/588981.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号