mmdetection 详解:训练自己的模型
1. train_detector() 函数详解
1
2
3
4
5
6
7
# Start training: hand over the built model, the training dataset and the
# loaded config, plus the distributed / validate flags and the logger.
train_detector(
    model,
    train_dataset,
    cfg,
    distributed=distributed,
    validate=args.validate,
    logger=logger)
参数:
model : 构建的网络模型
train_dataset : 构建的训练数据集
cfg : 读取的Config py文件
distributed : 是否是分布式训练
validate : whether to evaluate the checkpoint during training
logger : 日志信息
接下来我们详细看一下 train_detector() 函数。
./mmdet/apis/train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    """Entry point for training: dispatch to the matching training routine.

    Args:
        model: Model to train; forwarded unchanged to the training routine.
        dataset: Training dataset; forwarded unchanged.
        cfg: Config object; ``cfg.log_level`` is read when no logger is given,
            and the whole object is forwarded to the training routine.
        distributed: When truthy ``_dist_train`` is used, otherwise
            ``_non_dist_train``.
        validate: Forwarded as the ``validate`` keyword argument.
        logger: Optional logger. When ``None``, ``get_root_logger`` is called
            with ``cfg.log_level``; the logger is not used any further inside
            this function.
    """
    # Obtain the root logger only when the caller did not supply one.
    logger = get_root_logger(cfg.log_level) if logger is None else logger

    # start training: pick the routine once, then launch it
    launch = _dist_train if distributed else _non_dist_train
    launch(model, dataset, cfg, validate=validate)
1.1 _non_dist_train():非分布式训练(通过 MMDataParallel 使用 cfg.gpus 指定数量的 GPU)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def _non_dist_train(model, dataset, cfg, validate=False):
if validate:
raise NotImplementedError('Built-in validation is not implemented '
'yet in not-distributed training. Use '
'distributed training or test.py and '
'*eval.py scripts instead.')
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
# build runner
optimizer = build_optimizer(model, cfg.optimizer, cfg.get('optimizer_exclude_arch'))
arch_name = None
optimizer_arch = None
if 'optimizer_arch' in cfg:
raise NotImplementedError
runner = Runner(model, batch_processor, optimizer, optimizer_arch, cfg.work_dir, cfg.log_level, arch_name=arch_name)
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=False)
else:
optimizer_config = cfg.optimizer_config
optimizer_arch_config = cfg.optimizer_config
runner.register_training_hooks(cfg.lr_config, optimizer_config, optimizer_arch_config,
cfg.checkpoint_config, cfg.log_config)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
if 'optimizer_arch' in cfg:
raise NotImplementedError
else:
data_loaders = [
build_dataloader(
dataset,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False)
]
runner.run(data_loaders, None, cfg.workflow, cfg.total_epochs)
1.1.1 构建Runner实例
1
2
3
# Build the runner from the model, the per-batch callback and the optimizer.
runner = Runner(model, batch_processor, optimizer, optimizer_arch, cfg.work_dir, cfg.log_level, arch_name=arch_name)
# (intermediate setup from _non_dist_train omitted in this excerpt)
...
# Start the run; cfg.workflow and cfg.total_epochs are passed to the runner.
runner.run(data_loaders, None, cfg.workflow, cfg.total_epochs)
这里我们来看 Runner 的第二个参数 batch_processor:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def batch_processor(model, data, train_mode, **kwargs):
    """Run the model on one batch and package the result for the runner.

    The model is called with ``data`` unpacked as keyword arguments. Its
    result is indexed at ``[0]`` (the loss dict) and ``[1]`` (a latency loss
    or ``None``). A latency loss that is not ``None`` is stored in the loss
    dict under ``'loss_latency'`` before the dict is passed to
    ``parse_losses``.

    Args:
        model: Callable invoked as ``model(**data)``.
        data: Mapping of keyword arguments for the model; ``data['img'].data``
            is used to count the samples in the batch.
        train_mode: Not used by this function.
        **kwargs: Not used by this function.

    Returns:
        dict: ``loss`` and ``log_vars`` as returned by ``parse_losses``, plus
        ``num_samples`` set to ``len(data['img'].data)``.
    """
    model_out = model(**data)
    loss_terms, latency_term = model_out[0], model_out[1]

    # Fold the optional latency term into the regular loss dict.
    if latency_term is not None:
        loss_terms['loss_latency'] = latency_term

    total_loss, logged_vars = parse_losses(loss_terms)
    return dict(
        loss=total_loss,
        log_vars=logged_vars,
        num_samples=len(data['img'].data))