ImageDataGenerator是一个高级类,它允许从多个来源(从
nparrays,从目录…)产生数据,并且包括执行图像增强等功能的实用程序功能。
更新
从keras-preprocessing
1.0.4开始,
ImageDataGenerator提供了
flow_from_dataframe一种解决您的情况的方法。它要求
dataframe和
directory参数定义如下:
dataframe: Pandas dataframe containing the filenames of theimages in a column and classes in another or column/sthat can be fed as raw target data.directory: string, path to the target directory that contains allthe images mapped in the dataframe.
因此,您不再需要自己实施它。
下面的原始答案
对于您的情况,使用描述的数据框,您还可以编写自己的自定义生成器,该生成器将
prepare_data函数中的逻辑用作更简单的解决方案。最好使用Keras的
Sequence对象这样做,因为它允许使用多重处理(如果您使用的是gpu,这将有助于避免瓶颈)。
您可以签出有关对象的文档
Sequence,其中包含一个实现示例。最终,您的代码将遵循以下原则(这是样板代码,您将不得不添加诸如
label2int函数或图像预处理逻辑之类的细节):
from keras.utils import Sequenceclass DataSequence(Sequence): """ Keras Sequence object to train a model on larger-than-memory data. """ def __init__(self, df, batch_size, mode='train'): self.df = df # your pandas dataframe self.bsz = batch_size # batch size self.mode = mode # shuffle when in train mode # Take labels and a list of image locations in memory self.labels = self.df['label'].values self.im_list = self.df['image_name'].tolist() def __len__(self): # compute number of batches to yield return int(math.ceil(len(self.df) / float(self.bsz))) def on_epoch_end(self): # Shuffles indexes after each epoch if in training mode self.indexes = range(len(self.im_list)) if self.mode == 'train': self.indexes = random.sample(self.indexes, k=len(self.indexes)) def get_batch_labels(self, idx): # Fetch a batch of labels return self.labels[idx * self.bsz: (idx + 1) * self.bsz] def get_batch_features(self, idx): # Fetch a batch of inputs return np.array([imread(im) for im in self.im_list[idx * self.bsz: (1 + idx) * self.bsz]]) def __getitem__(self, idx): batch_x = self.get_batch_features(idx) batch_y = self.get_batch_labels(idx) return batch_x, batch_y
您可以像自定义生成器一样传递此对象来训练模型:
sequence = DataSequence(dataframe, batch_size)model.fit_generator(sequence, epochs=1, use_multiprocessing=True)
如下所述,不需要实现改组逻辑。
shuffle只需
True在
fit_generator()调用中将参数设置为即可。从文档:
shuffle:布尔值。是否在每个纪元开始时重新整理批次的顺序。仅用于Sequence实例(keras.utils.Sequence)。当steps_per_epoch不为None时无效。



