栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

八. airplane,lake图像分类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

八. airplane,lake图像分类

import glob  # glob is a built-in Python module for matching file paths
import os
import random

import numpy as np
import tensorflow as tf

def load_image(path):
    """Load one image file as a float tensor resized to 256x256 and scaled by 1/255.

    Compressed formats (jpg, png, gif, ...) must be decoded back into a
    height x width x channels array before use; decoding is forced to
    3 channels.

    Args:
        path: Scalar string tensor (or str) naming the image file.

    Returns:
        The resized image divided by 255.0.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    # Default (bilinear) tf.image.resize already returns float32, so no
    # explicit cast is needed before the division below.
    resized = tf.image.resize(pixels, [256, 256])
    return resized / 255.0

def parse_label(path, label_to_index):
    """Return the integer class label encoded in an image file name.

    Image files are expected to be named ``<class>_<id>.jpg`` (for example
    ``airplane_001.jpg``); the text before the first underscore of the file
    name is looked up in *label_to_index*.

    Args:
        path: Path to an image file, as returned by ``glob.glob``.
        label_to_index: Mapping from class-name prefix to integer label.

    Returns:
        The integer label for the file.

    Raises:
        ValueError: If the prefix is not a key of *label_to_index*.
    """
    # os.path.basename splits on the platform's native separator(s), which is
    # exactly what glob returns, so this works on Windows and POSIX alike.
    class_name = os.path.basename(path).split("_")[0]
    try:
        return label_to_index[class_name]
    except KeyError:
        # Fail loudly here instead of emitting a None label that would only
        # surface later as an obscure tf.data conversion error.
        raise ValueError(
            f"unknown class {class_name!r} in image path {path!r}"
        ) from None


def build_model():
    """Build the VGG-style binary image classifier.

    Five convolutional stages of 3x3 ReLU convolutions, each followed by
    batch normalization; the first four stages end in max-pooling. Then
    global average pooling and a dense head ending in a single sigmoid
    unit, i.e. the predicted probability of label 1.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model for (256, 256, 3) inputs.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential()

    # (filters, number of conv layers, max-pool at the end of the stage?)
    stages = [
        (64, 2, True),
        (128, 2, True),
        (256, 2, True),
        (512, 2, True),
        (512, 3, False),
    ]
    for stage_index, (filters, n_convs, pool) in enumerate(stages):
        for conv_index in range(n_convs):
            if stage_index == 0 and conv_index == 0:
                # Only the very first conv declares the input shape and uses
                # 'same' padding; every later conv uses the default padding.
                conv = layers.Conv2D(filters, (3, 3), input_shape=(256, 256, 3),
                                     padding='same', activation='relu')
            else:
                conv = layers.Conv2D(filters, (3, 3), activation='relu')
            model.add(conv)
            model.add(layers.BatchNormalization())
        if pool:
            model.add(layers.MaxPooling2D())

    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.BatchNormalization())
    # The sigmoid output is the final layer: BinaryCrossentropy (with the
    # default from_logits=False) needs probabilities in [0, 1], so nothing
    # (e.g. BatchNormalization) may rescale it afterwards.
    model.add(layers.Dense(1, activation='sigmoid'))
    return model


if __name__ == '__main__':
    label_to_index = {
        'airplane': 0,
        'lake': 1,
    }
    batch_size = 16

    # e.g. ['2_class/airplane/airplane_001.jpg', '2_class/airplane/airplane_002.jpg', ...]
    all_image_path = glob.glob('2_class/*/*.jpg')
    if not all_image_path:
        raise FileNotFoundError("no images found matching '2_class/*/*.jpg'")
    # Shuffle the paths once so the take/skip split below is a random split.
    random.shuffle(all_image_path)
    # e.g. [0, 1, 0, 0, 0, 0, 0, 1, 0, ...]
    all_labels = [parse_label(p, label_to_index) for p in all_image_path]

    img_ds = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_image)
    label_ds = tf.data.Dataset.from_tensor_slices(all_labels)
    img_label_ds = tf.data.Dataset.zip((img_ds, label_ds))

    # Hold out 20% for validation. The dataset itself is not shuffled before
    # this split, so take() and skip() always see the same element order and
    # the two subsets never overlap.
    image_count = len(all_image_path)
    test_count = int(image_count * 0.2)
    train_count = image_count - test_count
    # Training: skip the held-out head, repeat forever, shuffle through a
    # 100-element buffer, then batch.
    train_ds = img_label_ds.skip(test_count).repeat().shuffle(100).batch(batch_size)
    test_ds = img_label_ds.take(test_count).batch(batch_size)
    print(train_ds)

    model = build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.0001),
        loss=tf.keras.losses.BinaryCrossentropy(),  # expects probabilities
        metrics=['acc'],
    )
    history = model.fit(
        train_ds,
        epochs=10,
        # train_ds repeats forever, so the epoch length must be given; use at
        # least one step so datasets smaller than one batch still train.
        steps_per_epoch=max(1, train_count // batch_size),
        # test_ds is finite, so without validation_steps Keras evaluates all
        # of it (including the last partial batch) every epoch; skip
        # validation entirely when no images were held out.
        validation_data=test_ds if test_count else None,
    )
    print(history.history)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331162.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号