栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

利用VGG16卷积神经网络模型数据做五种图片检测

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

利用VGG16卷积神经网络模型数据做五种图片检测

把文件vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5放入C:UsersDELL.kerasmodels目录下,本文采用VGG16模型训练好的权重和偏置值搭建卷积神经网络,其中没有更改卷积层和池化层模型结构,使用两层神经网络简单识别data目录下的五种图片进行分类。

from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
import tensorflow as tf
import numpy as np


# 读取本地图片
class TransferModel(object):
    def __init__(self):
        # 定义测试和图片的变化方式
        self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        # 指定训练数据和测试数据的目录
        self.train_dir = "./data/train"
        self.test_dir = "./data/text"
        # 定义图片网络大小参数
        self.image_size = (224, 224)
        self.batch_size = 32
        # 定义迁移学习的类型模型,不包含全连接层模型加载
        self.base_model = VGG16(include_top=False)
        self.lable_dict = {
            '0': 'bus',
            '1': 'cat',
            '2': 'dog',
            '3': 'flower',
            '4': 'tree'
        }

    def get_local_data(self):
        # 读取本地数据
        # 获取数据集
        train_gen = self.train_generator.flow_from_directory(self.train_dir, shuffle=True,
                                                             target_size=self.image_size, class_mode='binary',
                                                             batch_size=self.batch_size)
        # 训练数据集
        test_gen = self.test_generator.flow_from_directory(self.test_dir, shuffle=True,
                                                           target_size=self.image_size, class_mode='binary',
                                                           batch_size=self.batch_size)
        return train_gen, test_gen

    def refine_base_model(self):
        # 微调VGG,减少迁移学习参数数量
        # 获取原先的notop输出,在输出后面增加结构定义新的迁移学习模型
        x = self.base_model.outputs[0]
        x = keras.layers.GlobalAveragePooling2D()(x)
        x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
        y_prediction = keras.layers.Dense(5, activation=tf.nn.softmax)(x)
        five_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_prediction)
        return five_model

    def freeze_model(self):
        # 冻结Vgg模型,根据数据量来判断,获取所有车=层,返回层的列表
        for layers in self.base_model.layers:
            layers.trainable = False

    def compile(self, model):
        model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.sparse_categorical_crossentropy,
                      metrics=["accuracy"])
        return None

    def fit_generator(self, model, train_gen, test_gen):
        # 每一次迭代的准确数
        modelckpt = keras.callbacks.ModelCheckpoint(
            'D:/PYCHARM ITEM/vgg16pretrain/ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
            monitor='val_acc',
            save_weights_only=True,
            save_best_only=True,
            mode='auto',
            period=1)
        model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt])
        return None

    def predict(self, model):
        # 加载模型
        model.load_weights("./ckpt/transfer_02-0.73.h5")
        # 读取图片
        image = load_img("D:/PYCHARM ITEM/vgg16pretrain/data/text/flower/src=https://www.mshxw.com/skin/sinaskin/image/nopic.gif",
                         target_size=(224, 224))
        image = img_to_array(image)
        img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])

        # model.predict()
        # 预测结果进行处理
        image = preprocess_input(img)
        predictions = model.predict(image)
        res = np.argmax(predictions, axis=1)
        print(self.lable_dict[str(res[0])])

# def predict():
#     model = VGG16()
#     print(model.summary())


if __name__ == '__main__':
    tm = TransferModel()
    # 训练
    # train_gen, test_gen = tm.get_local_data()
    # # print(train_gen)
    # # print(test_gen)
    # # for data in train_gen:
    # #     print(data[0].shape, data[1].shape)
    # # print(tm.base_model.summary())
    # # print(tm.refine_base_model())
    # model = tm.refine_base_model()
    # tm.freeze_model()
    # tm.compile(model)
    # tm.fit_generator(model, train_gen, test_gen)
    # print(model.summary())
    # 测试
    model = tm.refine_base_model()
    tm.predict(model)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/588841.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号