Detector🎯3: mmdetection Training Script train.py Explained

"mmdetection 开始训练模型"

Posted by fuhao7i on January 13, 2021

mmdetection in depth: training your own model

1. The train_detector() function in detail

train_detector(
    model,
    train_dataset,
    cfg,
    distributed=distributed,
    validate=args.validate,
    logger=logger)

Parameters:

model: the detector model built from the config
train_dataset: the training dataset built from the config
cfg: the Config object parsed from the config .py file
distributed: whether to run distributed training
validate: whether to evaluate the checkpoint during training
logger: the logger for training output
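
For context, this is roughly how tools/train.py assembles those arguments before the call. The sketch below follows the mmdetection v1.x API (build_detector, build_dataset, get_root_logger); the config path is only an example:

from mmcv import Config
from mmdet.apis import get_root_logger, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # example config
cfg.gpus = 1

logger = get_root_logger(cfg.log_level)
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = build_dataset(cfg.data.train)

train_detector(
    model,
    train_dataset,
    cfg,
    distributed=False,
    validate=False,
    logger=logger)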

Next, let's take a detailed look at the train_detector() function itself:
./mmdet/apis/train.py

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    # create a root logger if the caller did not pass one in
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training: dispatch to the distributed or non-distributed loop
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)
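
Where does distributed come from? In tools/train.py it is derived from the --launcher command-line argument; a condensed excerpt of the v1.x logic (args and cfg come from the surrounding script):

from mmdet.apis import init_dist

if args.launcher == 'none':
    distributed = False
else:
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)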

1.1 _non_dist_train(): non-distributed training (single machine)

def _non_dist_train(model, dataset, cfg, validate=False):
    if validate:
        raise NotImplementedError('Built-in validation is not implemented '
                                  'yet in not-distributed training. Use '
                                  'distributed training or test.py and '
                                  '*eval.py scripts instead.')
    # wrap the model with MMDataParallel and move it onto the visible GPUs
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build the weight optimizer from the config
    optimizer = build_optimizer(model, cfg.optimizer,
                                cfg.get('optimizer_exclude_arch'))

    # the architecture optimizer only exists in the NAS variant of this code;
    # plain training leaves both placeholders as None
    arch_name = None
    optimizer_arch = None
    if 'optimizer_arch' in cfg:
        raise NotImplementedError

    runner = Runner(model, batch_processor, optimizer, optimizer_arch,
                    cfg.work_dir, cfg.log_level, arch_name=arch_name)

    # fp16 setting: wrap the optimizer step in an Fp16OptimizerHook if enabled
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config
    # hook config for the architecture optimizer; assigned unconditionally
    # so it is defined even when the fp16 branch is taken
    optimizer_arch_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   optimizer_arch_config,
                                   cfg.checkpoint_config, cfg.log_config)

    # resume full training state, or just load pretrained weights
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    if 'optimizer_arch' in cfg:
        raise NotImplementedError
    else:
        data_loaders = [
            build_dataloader(
                dataset,
                cfg.data.imgs_per_gpu,
                cfg.data.workers_per_gpu,
                cfg.gpus,
                dist=False)
        ]
        runner.run(data_loaders, None, cfg.workflow, cfg.total_epochs)
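
One detail worth noting: in the non-distributed path, build_dataloader builds a single DataLoader whose batch folds in the GPU count, and MMDataParallel later scatters each batch across the devices. A minimal usage sketch (keyword names follow the v1.x signature; the numbers are illustrative):

from mmdet.datasets import build_dataloader

# yields batches of imgs_per_gpu * num_gpus samples per iteration
loader = build_dataloader(
    dataset,
    imgs_per_gpu=2,      # cfg.data.imgs_per_gpu
    workers_per_gpu=2,   # cfg.data.workers_per_gpu
    num_gpus=1,          # cfg.gpus
    dist=False)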

1.1.1 Building the Runner instance

runner = Runner(model, batch_processor, optimizer, optimizer_arch, cfg.work_dir, cfg.log_level, arch_name=arch_name)
...
runner.run(data_loaders, None, cfg.workflow, cfg.total_epochs)
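
Judging by the signature, the None in the second slot is the architecture-search dataloader that this modified Runner can accept; ordinary training passes nothing there. The last two arguments drive the schedule: cfg.workflow is a list of (mode, epochs) pairs that the Runner cycles through until cfg.total_epochs epochs have run. Typical config values look like this (illustrative):

workflow = [('train', 1)]   # one 'train' epoch per cycle, no 'val' phase
total_epochs = 12           # runner.run stops after 12 epochs in total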

Here, let's look at the Runner's second argument, batch_processor:

def batch_processor(model, data, train_mode, **kwargs):
    # forward pass; in this (NAS-modified) codebase the model returns a
    # tuple of (loss dict, optional latency loss)
    losses = model(**data)

    losses_ = losses[0]
    loss_latency = losses[1]
    if loss_latency is not None:
        losses_['loss_latency'] = loss_latency

    # reduce the loss dict to a single scalar plus loggable values
    loss, log_vars = parse_losses(losses_)

    # the Runner expects a dict with the total loss, the values to log,
    # and the number of samples in this batch
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs
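
parse_losses, defined in the same train.py, reduces the loss dict into a single scalar for backpropagation plus plain floats for logging. An abridged sketch that follows the v1.x implementation:

from collections import OrderedDict

import torch


def parse_losses(losses):
    log_vars = OrderedDict()
    # average each loss term (a tensor, or a list of per-level tensors)
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    # only keys containing 'loss' contribute to the training objective
    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()  # tensor -> float for logging

    return loss, log_vars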