栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

MMSegmentation 训练测试全流程

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

MMSegmentation 训练测试全流程

MMSegmentation 训练测试全流程

1.按照执行顺序的流程梳理

Level 0: 运行 Shell 命令:Level 1: 在 tools/train.py 内:Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterbasedRunner 类的 run 函数内部:Level 4: 转进到 IterbasedRunner 类的 train 函数内部Level 5: 转进到 evalHook 类实例的 after_train_iter 函数内部: 4.函数说明:5.疑问解答参考链接:

括号的部分可以不看!是debug经过的内容,有些事调用了mmcv库的函数,只想看看流程不需要细看!

1.按照执行顺序的流程梳理 Level 0: 运行 Shell 命令:

python tools/train.py ${CONFIG_FILE [optional arguments] Level 1: 在 tools/train.py 内:

读取各种 config: cfg = Config.fromfile(args.config)创建 model: model = build_segmentor(cfg.model, train_cfg, test_cfg)创建 training dataset: datasets = [build_dataset(cfg.data.train)]()

通过Config类的__getattr__函数:value = super(ConfigDict, self).__getattr_获取数据和数据增强信息并返回value转到mmseg/datasets/builder.py内的build_dataset函数,获取dataset:dataset = build_from_cfg(cfg, DATASETS, default_args)转到/usr/local/lib/python3.8/dist-packages/mmcv/utils/registry.py内的build_from_cfg函数:

args = cfg.copy()获取数据格式类型:obj_type = args.pop('type'),比如obj_type:ADE20KDatase通过数据格式obj_type获得类obj_cls = registry.get(obj_type),比如获取return obj_cls(**args) (转到/usr/lib/python3.8/typing.py的Generic类的__new__函数:obj = super().__new__(cls))转到mmseg/datasets/ade.py中的ADE20KDataset类的__init__函数:super(ADE20KDataset, self).__init__(**转到mmseg/datasets/custom.py中ADE20KDatase类继承的CustomDataset类

调用loading.py中LoadAnnotations类进行初始化,获得image和mask的地址等信息,并获取image和mask名字的dict:self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,self.ann_dir,self.seg_map_suffix, self.split)实例对象做运算时,就会调用CustomDataset类中的__getitem__()__:self.prepare_train_img(idx)调用prepare_train_img函数:self.pipeline(results),调用mmseg/datasets/pipelines/loading.py的LoadImageFromFile类和其他数据增强 创建 validation dataset: datasets.append(build_dataset(val_dataset))将 model, data, config 喂给训练函数: train_segmentor(model, datasets, cfg) Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:

创建 dataloader: data_loaders = [build]()_dataloader(dataset, config)]将 model 搬到 GPU 上去: model = MMDataParallel(model.cuda(), cfg)创建 optimizer: optimizer = build_optimizer(model, cfg)创建 runner: runner = build_runner(model, cfg, optimizer)给 runner 注册 training hooks: runner.register_training_hooks(cfg)给 runner 注册 validation hooks: runner.register_hook(eval_hook(val_dataloader, eval_cfg))

这个 eval_hook 是 evalHook 类实例, 其重写了 after_train_iter 和 after_train_epoch 两个方法, 在 IterbasedRunner 中用的是 after_train_iter。 开始训练 runner.run(data_loaders, cfg.workflow) Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterbasedRunner 类的 run 函数内部:

Training 模式, mode = 'train', i = 0, 运行 iter_runner(iter_loaders[i](), **kwargs)

实质上是在运行 IterbasedRunner类的 train 函数: train(iter_loaders[0](), **kwargs)从 while self.iter < self._max_iters: 可以看到, 这个 train 函数一共会被调用 self._max_iters 次从中也可以看到这个 train 函数其实只负责做一个 batch 数据的 forward 计算 Validation 模式, 此处其实没有运行

mmseg 的所有 setting 都是 workflow = [('train', 1)]实际上的 validation 是通过在 after_train_epoch 节点调用 evalHook 对象的 after_train_iter方法实现的。 Level 4: 转进到 IterbasedRunner 类的 train 函数内部

读取一个 batch 的数据: data_batch = next(data_loader)调用 model 的 train_step 函数计算 loss: outputs = self.model.train_step(data_batch)尝试选择性进行 validation:self.call_hook('after_train_iter')

实质上是调用 evalHook 类实例的 after_train_iter 函数; Level 5: 转进到 evalHook 类实例的 after_train_iter 函数内部:

如果当前迭代数不能够被 interval 整除, 就不做 validation: if not self.every_n_iters(runner, self.interval): return如果能被整除, 计算一下 validation set 上的结果: results = single_gpu_test(model, dataloader)

这一步就是 enumerate 一下 data_loader, 对于每个 batch 都用 model forward 一下, 把 result 都 append 起来得到一个 list results, 就不再展开了 对于分割结果再调用 dataset 的 evaluate 函数计算一下 mIoU, mDice, mFscore 等 metric 数值

其实就是通过调用下 mmseg.core 里面的 eval_metrics 函数调用 total_intersect_and_union 函数计算下上述数值 4.函数说明:

self.pipeline = Compose(pipeline)

Compose:把函数组合起来,每个函数的返回值是下一个函数的参数

print_log(f’Loaded {len(img_infos)} images’, logger=get_root_logger())

print_log:打印日志

target = torch.where(target == ignore_index, target.new_tensor(0), target)

torch.where:查找 target 中值为ignore_index(255)的值转为0,new_tensor:target.new_tensor是将target的值copy一份,不共享内存,new_tensor(0)指值为0同样size矩阵 5.疑问解答

CustomDataset类中pre_eval函数的ignore_index=255是起什么作用的? 是不计算255的loss吗

在mmseg/core/evaluation/metrics.py函数中找到了答案intersect_and_union函数中计算IOU的时候,将ignore_index=255的值忽略掉:mask = (label != ignore_index),相当于不计算背景的准确率,获取到的相当于是召回率Recall。需要注意的是,其中reduce_zero_label=True时,是将像素值为0的转为255:label[label == 0] = 255,会在mask = (label != ignore_index)处一并忽略 注意:255值在标注员工标注过程中代表不需要标注的区域,相当于背景,在需要标注的区域,背景值是0 参考链接:

【1】MMSegmentation 训练测试全流程及其关键节点

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/743794.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号