栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

Tensorflow的批量学习

Tensorflow的批量学习

使用Tensorflow最方便的在于可以使用fit函数直接封装训练,但是如果要处理大数据样本,就可能需要先构造生成器了。

使用“yield”

对于函数返回 yield 的通俗理解就是返回了一个存储函数的地址,在某个空间有这个暂时用不到的函数,而这个函数本来是要返回一个容器,比方 list :

def func2():
    yield [1,2]
 
b = func2()
print(b)

def func3():
    for x in range(2):
        yield x ** 2
 
c = func3()
 
for x in c:
    print(x)

只有当生成器调用成员方法时,生成器中的代码才会执行。

一个简单的minibatch

def minibatches(inputs=None, batch_size=10):
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]  # 提取相应的样本数据和标签数据
a = minibatches(list)
for i in a:
    a, b = data_generation(i)
    print(a.shape)
构造类似pytorch的sequence生成器
class DataGenerator(keras.utils.Sequence):

    def __init__(self, datas, batch_size=1, shuffle=True):
        self.batch_size = batch_size * 10
        self.datas = datas
        self.indexes = np.arange(len(self.datas))
        self.shuffle = shuffle

    def __len__(self):
        # 计算每一个epoch的迭代次数

        return math.ceil(len(self.datas) / float(self.batch_size))

    def __getitem__(self, index):
        # 生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
        # 生成batch_size个索引
        batch_indexs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # 根据索引获取datas集合中的数据
        batch_datas = [self.datas[k] for k in batch_indexs]

        # 生成数据
        X, y = self.data_generation(batch_datas)

        return X, y

    def on_epoch_end(self):
        # 在每一次epoch结束是否需要进行一次随机,重新随机一下index
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
# a = DataGenerator(list)
# print(a.__getitem__(0).shape)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/780431.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号