栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

python读取CIFAR10数据集并将数据集转换为PNG格式存储

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

python读取CIFAR10数据集并将数据集转换为PNG格式存储

CIFAR10数据集介绍

CIFAR10数据集包括10类图像,每张图像的大小为32*32,包含如上图的十个类别的对象。每个类都包含6000张图片,总共有60000张图片,数据集平衡。其中,训练组图像包含50000张图片,测试集包含10000张图像。

数据集的下载

数据集地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
也可以使用pytorch中的方法来获取数据集:

trainset = torchvision.datasets.CIFAR10(root='存储路径',
                                        train=True,
                                        download=True,
                                        transform = transform,
                                        )
testset = torchvision.datasets.CIFAR10(root='存储路径',
                                       train=False,
				       download=True,
                                       transform = transform,
                                       )

下载后的数据集如下:

包含五个训练batch和一个测试batch,每个batch包含一万张图片。在做深度学习训练的时候直接从batch中读取数据就好,也可以转换为PNG或者JPG图片格式来再进行读取和查看图像数据。
读取代码如下:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/5 下午1:02

import cv2
import numpy as np
from six.moves import cPickle as pickle
#解压缩二进制文件
def unpack(file):
    fo = open(file, "rb")
    dict = pickle.load(fo,encoding='iso-8859-1')
    fo.close()
    return dict

## unpack trainset

for i in range(1,6):
    data_name = "训练batch路径" + str(i)
    Xtr = unpack(data_name)
    print(data_name + 'is loading....')

    for j in range(10000):
        img = np.reshape(Xtr['data'][j],(3,32,32))
        img = img.transpose(1,2,0)
        img_name = 'train/' + str(Xtr['labels'][j]) + '_' + str(j+ (i-1)*10000) + '.jpg'
        cv2.imwrite(img_name,img)
    print(data_name + 'is loaded....')

testXtr = unpack('测试batch路径')
for i in range(0,10000):
    img = np.reshape(testXtr['data'][i],(3,32,32))
    img = img.transpose(1,2,0)
    img_name = 'test/' + str(testXtr['labels'][i]) + '_' + str(i) + '.jpg'
    cv2.imwrite(img_name, img)

在python3中解压二进制文件要带上这一句:

 dict = pickle.load(fo,encoding='iso-8859-1')

否则会出现编码错误。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/423523.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号