栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

【零基础-3】PaddlePaddle学习Bert

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【零基础-3】PaddlePaddle学习Bert

概要

【零基础-1】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121488335【零基础-2】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121523268

Cell 7
# 创建dataloader
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
Snippet  1
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):

create_dataloader,创建数据加载器,输入数据集dataset、模式mode(默认为训练集)、batchify_fn(未知,暂时理解成batchify_function,即batch化的函数)、trans_fn(转换样本的函数)。

Snippet 2
if trans_fn:
        dataset = dataset.map(trans_fn)

如果传入了trans_fn,就使用trans_fn将dataset进行一个转换,dataset.map的api文档如下

dataset — PaddleNLP 文档https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.datasets.dataset.html?highlight=dataset.map#paddlenlp.datasets.dataset.MapDataset.map

Snippet 3
shuffle = True if mode == 'train' else False

如果是训练集,就打乱,否则不打乱,这里的语法相当于C、Java的三目运算符

shuffle = mode == 'train' ? true : false 
Snippet 4
if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

如果传入的是训练集,则调用paddle.io.DistributedBatchSampler处理得到batch_sampler,如果传入的不是训练集,则调用paddle.io.BatchSampler处理得到batch_sampler。

这里为什么要得到batch_sampler呢?

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/601174.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号