栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

mmdetection - 训练过程之train

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

mmdetection - 训练过程之train

下面是train_detector的主干,我删除了异常判断、版本兼容、分布式训练等内容,下面列出来的是我认为比较重要的部分。

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    runner_type = 'EpochbasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']
        
#DataLoader,是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,
#一般只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的
#Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=cfg.data.get('persistent_workers', False))
        for ds in dataset
    ]

#构建优化器,optimizer目的:优化SGD,训练快速收敛并且保证准确率
   optimizer = build_optimizer(model, cfg.optimizer)
# build runner runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # register hooks 注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
#加载模型
    runner.load_checkpoint(cfg.load_from)
 # runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练
    runner.run(data_loaders, cfg.workflow)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/739902.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号