下面是train_detector的主干,我删除了异常判断、版本兼容、分布式训练等内容,下面列出来的是我认为比较重要的部分。
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
runner_type = 'EpochbasedRunner' if 'runner' not in cfg else cfg.runner[
'type']
#DataLoader,是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,
#一般只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的
#Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# `num_gpus` will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type,
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]
#构建优化器,optimizer目的:优化SGD,训练快速收敛并且保证准确率
optimizer = build_optimizer(model, cfg.optimizer)
# build runner runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
# register hooks 注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
#加载模型
runner.load_checkpoint(cfg.load_from)
# runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练
runner.run(data_loaders, cfg.workflow)



