栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Python神经网络5之数据读取2

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Python神经网络5之数据读取2

Python神经网络5之数据读取2
  • 数据读取
    • TFRecords
      • TFRecords文件
      • 案例:CIFAR10数据存入TFRecords文件
      • 读取TFRecords文件API
      • 案例:读取CIFAR的TFRecords文件

数据读取 TFRecords TFRecords文件

TFRecords其实是一种二进制文件,虽然不如其他格式好理解,但是能够更好的利用内容,更方便复制和移动,并且不需要单独的标签文件
使用步骤:
1.获取数据
2.将数据填入到Example协议内存块(protocol buffer)
3.将协议内存块序列化为字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecords文件

  • 文件格式 *.tfrecords

Example结构解析:

  • tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段Features)
    Features包含了一个Features字段
    Feature中包含要写入的数据,并指明数据类型。
    这是一个样本的结构,批数据需要循环存入这样的结构
  • tf.train.Example(features=None)
    写入tfrecords文件
    features:tf.train.Features类型的特征实例
    return:example格式协议块
  • tf.train.Features(feature=None)
    构建每个样本的信息键值对
    feature:字典数据,key为要保存的名字
    value为tf.train.Feature实例
    return:Features类型
  • tf.train.Feature(options)
    • options:例如
      • bytes_list=tf.train.BytesList(value=[Bytes])
      • int64_list=tf.train.Int64List(value=[Value])
    • 支持存入的类型如下
    • tf.train.Int64List(value=[Value])
    • tf.train.BytesList(value=[Bytes])
    • tf.train.FloatList(value=[value])

example = tf.train.Example(features=tf.train.Features(feature={
“image”:tf.train.Feature(bytes_list=tf.train. BytesList(value=[image])),
“label”:tf.train.Feature(int64_list=tf.train. Int64List(value=[label]))
}))
将example序列化:example.SerializeToString()

案例:CIFAR10数据存入TFRecords文件
  • 构造存储实例,tf.python_io.TFRecordWriter(path)
    • 写入tfrecords文件
    • path:TFRecords文件的路径
    • return:写文件
      • method方法
        • write(record):向文件中写入一个example
        • close():关闭文件写入器
  • 循环将数据填入到Example协议内存块(protocol buffer)
    def write_to_tfrecords(self, image_batch, label_batch):
        """
        将样本特征值和目标值一起写入tfrecords文件
        :param image:
        :param label:
        :return:
        """
        with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # 循环构造example对象,并序列化写入文件
            for i in range(100):
                image = image_batch[i].tostring()
                label = label_batch[i][0]
                # print("tfrecords_image:n",image)
                # print("tfrecords_label:n",label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # example.SerializeToString()
                # 将序列化后的example写入文件
                writer.write(example.SerializeToString())
        return None;

生成cifar10.tfrecords文件:




读取TFRecords文件API

读取这种文件整个过程与其他文件一样,只不过需要有个解析Example的步骤,从TFRecords文件中读取数据,可以使用tf.TFRecordReader的tf.parse_single_example解析器,这个操作可以将Example协议内存块(protocol buffer)的解析为张量

  • tf.parse_single_example(serialized,features=None,name=None)
    解析一个单一的Example原型
    serialized:标量字符串Tensor,一个序列化的Example
    features:dict字典数据,键为读取的名字,值为FixedLenFeature
    return:一个键值对组成的字典,键为读取的名字
  • tf.FixedLenFeature(shape,dtype)
    shape:输入数据的形状,一般不指定,为空列表
    dtype:输入数据类型,与存储进文件的类型要一致
    类型只能是float32,int64,string
案例:读取CIFAR的TFRecords文件
  1. 构造文件名队列
  2. 读取和解码
    读取
    解析Example
    解码
  3. 构造批处理队列
    def read_tfrecords(self):
        """
        读取TFRecords文件
        :return:
        """
        # 1.构造文件名队列
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2.读取与解码
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # 解析example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:n", image)
        print("read_tf_label:n", label)
        # 解码
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:n", image_decoded)
        # 图像形状调整
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshaped:n", image_reshaped)

        # 3.构造批处理队列
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:n", image_batch)
        print("label_batch:n", label_batch)

        # 开启会话
        with tf.Session() as sess:
            # 开启线程
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            #image_value, label_value, image_decoded_value = sess.run([image, label, image_decoded])
            image_value,label_value=sess.run([image_batch,label_batch])
            print("image_value:n", image_value)
            print("label_value:n", label_value)

            # 回收资源
            coord.request_stop()
            coord.join(threads)

        return None;




全部代码:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os


class Cifar(object):

    def __init__(self):
        # 初始化操作
        self.height = 32
        self.width = 32
        self.channels = 3

        # 设置图像字节数
        self.image = self.height * self.width * self.channels
        self.label = 1
        self.sample = self.image + self.label

    def read_binary(self):
        """
        读取二进制文件
        :param file_list:
        :return:
        """
        # 1.构造文件名队列
        filename_list = os.listdir("./cifar-10-batches-bin")
        print("file_name:n", filename_list)
        # 构造文件名路径列表
        file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in filename_list if file[-3:] == "bin"]
        print("file_list:n", file_list)
        file_queue = tf.train.string_input_producer(file_list)

        # 2.读取与解码
        reader = tf.FixedLengthRecordReader(self.sample)
        # key文件名 value一个样本
        key, value = reader.read(file_queue)
        print("key:n", key)
        print("value:n", value)
        # 解码阶段
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded:n", image_decoded)

        # 将目标值和特征值切片切开
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label:n", label)
        print("image:n", image)

        # 调整图片形状
        image_reshaped = tf.reshape(image, shape=[self.channels, self.height, self.width])
        print("image_reshaped:n", image_reshaped)

        # 转置,将图片的顺序转为height,width,channels
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:n", image_transposed)

        # 3.批处理
        label_batch, image_batch = tf.train.batch([label, image_transposed], batch_size=100, num_threads=2, capacity=100)
        print("label_batch:n", label_batch)
        print("image_batch:n", image_batch)

        # 开启会话
        with tf.Session() as sess:
            # 开启线程
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            label_value, image_value = sess.run([label_batch, image_batch])

            print("label_value:n", label_value)
            print("image_value:n", image_value)

            # 回收线程
            coord.request_stop()
            coord.join(threads)
        return label_value, image_value

    def write_to_tfrecords(self, image_batch, label_batch):
        """
        将样本特征值和目标值一起写入tfrecords文件
        :param image:
        :param label:
        :return:
        """
        with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # 循环构造example对象,并序列化写入文件
            for i in range(100):
                image = image_batch[i].tostring()
                label = label_batch[i][0]
                # print("tfrecords_image:n",image)
                # print("tfrecords_label:n",label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # example.SerializeToString()
                # 将序列化后的example写入文件
                writer.write(example.SerializeToString())
        return None;

    def read_tfrecords(self):
        """
        读取TFRecords文件
        :return:
        """
        # 1.构造文件名队列
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2.读取与解码
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # 解析example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:n", image)
        print("read_tf_label:n", label)
        # 解码
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:n", image_decoded)
        # 图像形状调整
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshaped:n", image_reshaped)

        # 3.构造批处理队列
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:n", image_batch)
        print("label_batch:n", label_batch)

        # 开启会话
        with tf.Session() as sess:
            # 开启线程
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            #image_value, label_value, image_decoded_value = sess.run([image, label, image_decoded])
            image_value,label_value=sess.run([image_batch,label_batch])
            print("image_value:n", image_value)
            print("label_value:n", label_value)

            # 回收资源
            coord.request_stop()
            coord.join(threads)

        return None;


if __name__ == "__main__":
    # 实例化Cifar
    cifar = Cifar()
    # label_value,image_value=cifar.read_binary()
    # cifar.write_to_tfrecords(image_value,label_value)
    cifar.read_tfrecords()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/823029.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号