Detector🎯2——mmdetection 数据集准备build_dataset函数详解

"mmdetection 数据集准备"

Posted by fuhao7i on January 13, 2021

mmdetection 数据集准备,包括训练数据集和测试数据集

1. 训练数据集

在train.py文件中,构建训练数据集。

1
2
3
from mmdet.datasets import build_dataset

train_dataset = build_dataset(cfg.data.train)

1.1 参数cfg.data.train

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# dataset settings
dataset_type = 'VOCDataset'
data_root ='data/VOCdevkit/VOC2007/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
       # img_scale=(1333, 800),
        img_scale=(800,600),
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]


train=dict(
        type=dataset_type,
        ann_file=data_root + 'ImageSets/Main/train.txt',
        img_prefix=data_root ,
        pipeline=train_pipeline)

1.2 函数build_dataset()

其中,build_dataset函数在mmdet文件夹的datasets文件夹下的builder.py。

./mmdet/datasets/builder.py

1
2
3
4
5
6
7
8
9
10
11
12
def build_dataset(cfg, default_args=None):
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif isinstance(cfg['ann_file'], (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset

这里我们会执行dataset = build_from_cfg(cfg, DATASETS, default_args). 我们来具体看一下。

1.2.1 全局变量(注册表)DATASETS的构建

mmdet/datasets/registry.py

这里将类Registry实例化,注册到注册表中。

1
2
3
4
from mmdet.utils import Registry

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')

然后我们具体看一下Registry类:

mmdet/utils/registry.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import inspect

import mmcv


class Registry(object):

    def __init__(self, name):
        self._name = name   #Registry类的名字
        self._module_dict = dict()  #创建一个模块字典

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property   #将方法修饰为类属性
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):   #Registry类的主要方法,用来注册模块。
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls

然后我们点开mmdet/datasets下的数据集定义py文件。这里我们以voc.py数据集为例。

1.2.2 数据集定义文件 voc.py

这里我们用修饰符@修饰: @DATASETS.register_module,就是将类VOCDataset作为参数传入DATASETS.register_module()函数。

1
2
@DATASETS.register_module
class VOCDataset(XMLDataset):

1.2.2.1 父类XMLDataset

VocDataset 继承自 XMLDataset, XMLDataset 又继承自 CustomDataset, CustomDataset继承自 Dataset from torch.utils.data import Dataset.

1
2
3
4
5
class XMLDataset(CustomDataset):
	def load_annotations(self, ann_file):  # 用于初始化VocDataset的,主要初始化算imgs的位置,imgs和annotations的关联关系等
		...
	def get_ann_info(self, idx):  # 用于在train/val/test中调取annotations
		...

然后看一下在训练中时如何调用的,XMLDataset继承CustomDatset,在train过程中调用的时CustomDataset的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# ./mmdnet/datasets/custom.py

class CustomDataset(Dataset):

    def __init__(...):
        ...
        # load annotations (and proposals)
        self.img_infos = self.load_annotations(self.ann_file)
        ...
        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __getitem__(self, idx):     # 重载Dataset的,就是在train过程中加载data和target的,其中用了prepare_train_img
       if self.test_mode:
           return self.prepare_test_img(idx)
       while True:
           data = self.prepare_train_img(idx)
           if data is None:
               idx = self._rand_another(idx)
               continue
           return data

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)
根据上述代码,可以看到CustomDataset中使用load_annotations初始化dataset,使用get_ann_info加载target,所以继承CustomDataset需要定义这两个函数,就可以完成自己Dataset的定义。

1.3 函数 build_from_cfg()

我们继续来看dataset = build_from_cfg(cfg, DATASETS, default_args).

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_type = registry.get(obj_type)   #将cfg中cfg中的'VOCDataset'变成VOCDataset这个类赋值给obj_cls
        if obj_type is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)     #实例化VOCDataset并返回,一路返回给test.py中的dataset
pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
1
2
3
4
5
6
7
# https://blog.csdn.net/g11d111/article/details/81504637 非常好的Dataloader详解
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ...

2. 测试数据集

和训练数据集的构建相同。

tips1: isinstance(object, classinfo)函数详解:

* object -- 实例对象。
* classinfo -- 可以是直接或间接类名、基本类型或者由它们组成的元组。

classinfo可以是:int,float,bool,complex,str(字符串),list,dict(字典),set,tuple,具体的类

判断对象object的类型是否和classinfo的类型相同。相同则返回True,否则返回False。

isinstance() 与 type() 区别: type() 不会认为子类是一种父类类型,不考虑继承关系。 isinstance() 会认为子类是一种父类类型,考虑继承关系。 如果要判断两个类型是否相同推荐使用 isinstance()。

e.g

1
2
3
>>> a = 2
>>> isinstance (a,(str,int,list))    # 是元组中的一个返回 True
True