栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Keras VGG16 训练自己图片程序

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Keras VGG16 训练自己图片程序

tensorflow-gpu                     2.2.0

 Keras                              2.3.1

import os
import keras
from keras.applications import VGG16
from keras import models
from keras import layers
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers


train_dir = 'E://pythonsave//mnist_data//picture//picture//train'                           #训练  

#validation_dir =                       #验证

test_dir =  'E://pythonsave//mnist_data//picture//picture//test'                           #测试


conv_base = VGG16(weights='imagenet',
                include_top=False,
                input_shape=(150, 150,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(units = 4, activation='softmax')) #“4”为分类数量,如需分10项,改为10




train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
        train_dir, 
        target_size=(150, 150), 
        batch_size=20,
        class_mode='categorical')   #如果是二分类问题该值可设为‘binary’

test_datagen = ImageDataGenerator(rescale=1./255)

validation_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='categorical')  #如果是二分类问题该值可设为‘binary’


callbacks_list = [
        keras.callbacks.EarlyStopping
        (
                monitor='acc',
                patience=3,
        ),
        keras.callbacks.ModelCheckpoint(
                filepath='E://pythonsave//laji.h5',
                save_best_only=True,
        )]


model.compile(loss='categorical_crossentropy',
        optimizer=optimizers.RMSprop(lr=2e-5),
        metrics=['acc'])






history = model.fit_generator(
 train_generator,
 steps_per_epoch=100,
 epochs=30,
 callbacks=callbacks_list,   
 validation_data=validation_generator,
 validation_steps=50)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/887481.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号